tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, Role,
 12};
 13use project::Project;
 14use ui::{IconName, Window};
 15use util::truncate_lines_to_byte_limit;
 16
 17use crate::thread::{MessageId, PromptId, ThreadId};
 18use crate::thread_store::SerializedMessage;
 19
 20#[derive(Debug)]
 21pub struct ToolUse {
 22    pub id: LanguageModelToolUseId,
 23    pub name: SharedString,
 24    pub ui_text: SharedString,
 25    pub status: ToolUseStatus,
 26    pub input: serde_json::Value,
 27    pub icon: ui::IconName,
 28    pub needs_confirmation: bool,
 29}
 30
 31pub struct ToolUseState {
 32    tools: Entity<ToolWorkingSet>,
 33    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 34    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 35    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 36    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
 37    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
 38}
 39
 40impl ToolUseState {
 41    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 42        Self {
 43            tools,
 44            tool_uses_by_assistant_message: HashMap::default(),
 45            tool_results: HashMap::default(),
 46            pending_tool_uses_by_id: HashMap::default(),
 47            tool_result_cards: HashMap::default(),
 48            tool_use_metadata_by_id: HashMap::default(),
 49        }
 50    }
 51
 52    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 53    ///
 54    /// Accepts a function to filter the tools that should be used to populate the state.
 55    ///
 56    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 57    /// tool cards won't be deserialized
 58    pub fn from_serialized_messages(
 59        tools: Entity<ToolWorkingSet>,
 60        messages: &[SerializedMessage],
 61        project: Entity<Project>,
 62        window: Option<&mut Window>, // None in headless mode
 63        cx: &mut App,
 64    ) -> Self {
 65        let mut this = Self::new(tools);
 66        let mut tool_names_by_id = HashMap::default();
 67        let mut window = window;
 68
 69        for message in messages {
 70            match message.role {
 71                Role::Assistant => {
 72                    if !message.tool_uses.is_empty() {
 73                        let tool_uses = message
 74                            .tool_uses
 75                            .iter()
 76                            .map(|tool_use| LanguageModelToolUse {
 77                                id: tool_use.id.clone(),
 78                                name: tool_use.name.clone().into(),
 79                                raw_input: tool_use.input.to_string(),
 80                                input: tool_use.input.clone(),
 81                                is_input_complete: true,
 82                            })
 83                            .collect::<Vec<_>>();
 84
 85                        tool_names_by_id.extend(
 86                            tool_uses
 87                                .iter()
 88                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 89                        );
 90
 91                        this.tool_uses_by_assistant_message
 92                            .insert(message.id, tool_uses);
 93
 94                        for tool_result in &message.tool_results {
 95                            let tool_use_id = tool_result.tool_use_id.clone();
 96                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 97                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 98                                continue;
 99                            };
100
101                            this.tool_results.insert(
102                                tool_use_id.clone(),
103                                LanguageModelToolResult {
104                                    tool_use_id: tool_use_id.clone(),
105                                    tool_name: tool_use.clone(),
106                                    is_error: tool_result.is_error,
107                                    content: tool_result.content.clone(),
108                                    output: tool_result.output.clone(),
109                                },
110                            );
111
112                            if let Some(window) = &mut window {
113                                if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
114                                    if let Some(output) = tool_result.output.clone() {
115                                        if let Some(card) = tool.deserialize_card(
116                                            output,
117                                            project.clone(),
118                                            window,
119                                            cx,
120                                        ) {
121                                            this.tool_result_cards.insert(tool_use_id, card);
122                                        }
123                                    }
124                                }
125                            }
126                        }
127                    }
128                }
129                Role::System | Role::User => {}
130            }
131        }
132
133        this
134    }
135
136    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
137        let mut cancelled_tool_uses = Vec::new();
138        self.pending_tool_uses_by_id
139            .retain(|tool_use_id, tool_use| {
140                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
141                    return true;
142                }
143
144                let content = "Tool canceled by user".into();
145                self.tool_results.insert(
146                    tool_use_id.clone(),
147                    LanguageModelToolResult {
148                        tool_use_id: tool_use_id.clone(),
149                        tool_name: tool_use.name.clone(),
150                        content,
151                        output: None,
152                        is_error: true,
153                    },
154                );
155                cancelled_tool_uses.push(tool_use.clone());
156                false
157            });
158        cancelled_tool_uses
159    }
160
161    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
162        self.pending_tool_uses_by_id.values().collect()
163    }
164
165    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
166        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
167            return Vec::new();
168        };
169
170        let mut tool_uses = Vec::new();
171
172        for tool_use in tool_uses_for_message.iter() {
173            let tool_result = self.tool_results.get(&tool_use.id);
174
175            let status = (|| {
176                if let Some(tool_result) = tool_result {
177                    return if tool_result.is_error {
178                        ToolUseStatus::Error(tool_result.content.clone().into())
179                    } else {
180                        ToolUseStatus::Finished(tool_result.content.clone().into())
181                    };
182                }
183
184                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
185                    match pending_tool_use.status {
186                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
187                        PendingToolUseStatus::NeedsConfirmation { .. } => {
188                            ToolUseStatus::NeedsConfirmation
189                        }
190                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
191                        PendingToolUseStatus::Error(ref err) => {
192                            ToolUseStatus::Error(err.clone().into())
193                        }
194                        PendingToolUseStatus::InputStillStreaming => {
195                            ToolUseStatus::InputStillStreaming
196                        }
197                    }
198                } else {
199                    ToolUseStatus::Pending
200                }
201            })();
202
203            let (icon, needs_confirmation) =
204                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
205                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
206                } else {
207                    (IconName::Cog, false)
208                };
209
210            tool_uses.push(ToolUse {
211                id: tool_use.id.clone(),
212                name: tool_use.name.clone().into(),
213                ui_text: self.tool_ui_label(
214                    &tool_use.name,
215                    &tool_use.input,
216                    tool_use.is_input_complete,
217                    cx,
218                ),
219                input: tool_use.input.clone(),
220                status,
221                icon,
222                needs_confirmation,
223            })
224        }
225
226        tool_uses
227    }
228
229    pub fn tool_ui_label(
230        &self,
231        tool_name: &str,
232        input: &serde_json::Value,
233        is_input_complete: bool,
234        cx: &App,
235    ) -> SharedString {
236        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
237            if is_input_complete {
238                tool.ui_text(input).into()
239            } else {
240                tool.still_streaming_ui_text(input).into()
241            }
242        } else {
243            format!("Unknown tool {tool_name:?}").into()
244        }
245    }
246
247    pub fn tool_results_for_message(
248        &self,
249        assistant_message_id: MessageId,
250    ) -> Vec<&LanguageModelToolResult> {
251        let Some(tool_uses) = self
252            .tool_uses_by_assistant_message
253            .get(&assistant_message_id)
254        else {
255            return Vec::new();
256        };
257
258        tool_uses
259            .iter()
260            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
261            .collect()
262    }
263
264    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
265        self.tool_uses_by_assistant_message
266            .get(&assistant_message_id)
267            .map_or(false, |results| !results.is_empty())
268    }
269
270    pub fn tool_result(
271        &self,
272        tool_use_id: &LanguageModelToolUseId,
273    ) -> Option<&LanguageModelToolResult> {
274        self.tool_results.get(tool_use_id)
275    }
276
277    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
278        self.tool_result_cards.get(tool_use_id)
279    }
280
281    pub fn insert_tool_result_card(
282        &mut self,
283        tool_use_id: LanguageModelToolUseId,
284        card: AnyToolCard,
285    ) {
286        self.tool_result_cards.insert(tool_use_id, card);
287    }
288
289    pub fn request_tool_use(
290        &mut self,
291        assistant_message_id: MessageId,
292        tool_use: LanguageModelToolUse,
293        metadata: ToolUseMetadata,
294        cx: &App,
295    ) -> Arc<str> {
296        let tool_uses = self
297            .tool_uses_by_assistant_message
298            .entry(assistant_message_id)
299            .or_default();
300
301        let mut existing_tool_use_found = false;
302
303        for existing_tool_use in tool_uses.iter_mut() {
304            if existing_tool_use.id == tool_use.id {
305                *existing_tool_use = tool_use.clone();
306                existing_tool_use_found = true;
307            }
308        }
309
310        if !existing_tool_use_found {
311            tool_uses.push(tool_use.clone());
312        }
313
314        let status = if tool_use.is_input_complete {
315            self.tool_use_metadata_by_id
316                .insert(tool_use.id.clone(), metadata);
317
318            PendingToolUseStatus::Idle
319        } else {
320            PendingToolUseStatus::InputStillStreaming
321        };
322
323        let ui_text: Arc<str> = self
324            .tool_ui_label(
325                &tool_use.name,
326                &tool_use.input,
327                tool_use.is_input_complete,
328                cx,
329            )
330            .into();
331
332        self.pending_tool_uses_by_id.insert(
333            tool_use.id.clone(),
334            PendingToolUse {
335                assistant_message_id,
336                id: tool_use.id,
337                name: tool_use.name.clone(),
338                ui_text: ui_text.clone(),
339                input: tool_use.input,
340                status,
341            },
342        );
343
344        ui_text
345    }
346
347    pub fn run_pending_tool(
348        &mut self,
349        tool_use_id: LanguageModelToolUseId,
350        ui_text: SharedString,
351        task: Task<()>,
352    ) {
353        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
354            tool_use.ui_text = ui_text.into();
355            tool_use.status = PendingToolUseStatus::Running {
356                _task: task.shared(),
357            };
358        }
359    }
360
361    pub fn confirm_tool_use(
362        &mut self,
363        tool_use_id: LanguageModelToolUseId,
364        ui_text: impl Into<Arc<str>>,
365        input: serde_json::Value,
366        request: Arc<LanguageModelRequest>,
367        tool: Arc<dyn Tool>,
368    ) {
369        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
370            let ui_text = ui_text.into();
371            tool_use.ui_text = ui_text.clone();
372            let confirmation = Confirmation {
373                tool_use_id,
374                input,
375                request,
376                tool,
377                ui_text,
378            };
379            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
380        }
381    }
382
383    pub fn insert_tool_output(
384        &mut self,
385        tool_use_id: LanguageModelToolUseId,
386        tool_name: Arc<str>,
387        output: Result<ToolResultOutput>,
388        configured_model: Option<&ConfiguredModel>,
389    ) -> Option<PendingToolUse> {
390        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
391
392        telemetry::event!(
393            "Agent Tool Finished",
394            model = metadata
395                .as_ref()
396                .map(|metadata| metadata.model.telemetry_id()),
397            model_provider = metadata
398                .as_ref()
399                .map(|metadata| metadata.model.provider_id().to_string()),
400            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
401            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
402            tool_name,
403            success = output.is_ok()
404        );
405
406        match output {
407            Ok(output) => {
408                let tool_result = output.content;
409                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
410
411                // Protect from clearly large output
412                let tool_output_limit = configured_model
413                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
414                    .unwrap_or(usize::MAX);
415
416                let tool_result = if tool_result.len() <= tool_output_limit {
417                    tool_result
418                } else {
419                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
420
421                    format!(
422                        "Tool result too long. The first {} bytes:\n\n{}",
423                        truncated.len(),
424                        truncated
425                    )
426                };
427
428                self.tool_results.insert(
429                    tool_use_id.clone(),
430                    LanguageModelToolResult {
431                        tool_use_id: tool_use_id.clone(),
432                        tool_name,
433                        content: tool_result.into(),
434                        is_error: false,
435                        output: output.output,
436                    },
437                );
438                self.pending_tool_uses_by_id.remove(&tool_use_id)
439            }
440            Err(err) => {
441                self.tool_results.insert(
442                    tool_use_id.clone(),
443                    LanguageModelToolResult {
444                        tool_use_id: tool_use_id.clone(),
445                        tool_name,
446                        content: err.to_string().into(),
447                        is_error: true,
448                        output: None,
449                    },
450                );
451
452                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
453                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
454                }
455
456                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
457            }
458        }
459    }
460
461    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
462        self.tool_uses_by_assistant_message
463            .contains_key(&assistant_message_id)
464    }
465
466    pub fn tool_results(
467        &self,
468        assistant_message_id: MessageId,
469    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
470        self.tool_uses_by_assistant_message
471            .get(&assistant_message_id)
472            .into_iter()
473            .flatten()
474            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
475    }
476}
477
478#[derive(Debug, Clone)]
479pub struct PendingToolUse {
480    pub id: LanguageModelToolUseId,
481    /// The ID of the Assistant message in which the tool use was requested.
482    #[allow(unused)]
483    pub assistant_message_id: MessageId,
484    pub name: Arc<str>,
485    pub ui_text: Arc<str>,
486    pub input: serde_json::Value,
487    pub status: PendingToolUseStatus,
488}
489
490#[derive(Debug, Clone)]
491pub struct Confirmation {
492    pub tool_use_id: LanguageModelToolUseId,
493    pub input: serde_json::Value,
494    pub ui_text: Arc<str>,
495    pub request: Arc<LanguageModelRequest>,
496    pub tool: Arc<dyn Tool>,
497}
498
499#[derive(Debug, Clone)]
500pub enum PendingToolUseStatus {
501    InputStillStreaming,
502    Idle,
503    NeedsConfirmation(Arc<Confirmation>),
504    Running { _task: Shared<Task<()>> },
505    Error(#[allow(unused)] Arc<str>),
506}
507
508impl PendingToolUseStatus {
509    pub fn is_idle(&self) -> bool {
510        matches!(self, PendingToolUseStatus::Idle)
511    }
512
513    pub fn is_error(&self) -> bool {
514        matches!(self, PendingToolUseStatus::Error(_))
515    }
516
517    pub fn needs_confirmation(&self) -> bool {
518        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
519    }
520}
521
522#[derive(Clone)]
523pub struct ToolUseMetadata {
524    pub model: Arc<dyn LanguageModel>,
525    pub thread_id: ThreadId,
526    pub prompt_id: PromptId,
527}