tool_use.rs

  1use crate::{
  2    thread::{MessageId, PromptId, ThreadId},
  3    thread_store::SerializedMessage,
  4};
  5use agent_settings::CompletionMode;
  6use anyhow::Result;
  7use assistant_tool::{
  8    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  9};
 10use collections::HashMap;
 11use futures::{FutureExt as _, future::Shared};
 12use gpui::{App, Entity, SharedString, Task, Window};
 13use icons::IconName;
 14use language_model::{
 15    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
 16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
 17    LanguageModelToolUseId, Role,
 18};
 19use project::Project;
 20use std::sync::Arc;
 21use util::truncate_lines_to_byte_limit;
 22
 23#[derive(Debug)]
 24pub struct ToolUse {
 25    pub id: LanguageModelToolUseId,
 26    pub name: SharedString,
 27    pub ui_text: SharedString,
 28    pub status: ToolUseStatus,
 29    pub input: serde_json::Value,
 30    pub icon: icons::IconName,
 31    pub needs_confirmation: bool,
 32}
 33
 34pub struct ToolUseState {
 35    tools: Entity<ToolWorkingSet>,
 36    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 37    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 38    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 39    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
 40    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
 41}
 42
 43impl ToolUseState {
 44    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 45        Self {
 46            tools,
 47            tool_uses_by_assistant_message: HashMap::default(),
 48            tool_results: HashMap::default(),
 49            pending_tool_uses_by_id: HashMap::default(),
 50            tool_result_cards: HashMap::default(),
 51            tool_use_metadata_by_id: HashMap::default(),
 52        }
 53    }
 54
 55    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 56    ///
 57    /// Accepts a function to filter the tools that should be used to populate the state.
 58    ///
 59    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 60    /// tool cards won't be deserialized
 61    pub fn from_serialized_messages(
 62        tools: Entity<ToolWorkingSet>,
 63        messages: &[SerializedMessage],
 64        project: Entity<Project>,
 65        window: Option<&mut Window>, // None in headless mode
 66        cx: &mut App,
 67    ) -> Self {
 68        let mut this = Self::new(tools);
 69        let mut tool_names_by_id = HashMap::default();
 70        let mut window = window;
 71
 72        for message in messages {
 73            match message.role {
 74                Role::Assistant => {
 75                    if !message.tool_uses.is_empty() {
 76                        let tool_uses = message
 77                            .tool_uses
 78                            .iter()
 79                            .map(|tool_use| LanguageModelToolUse {
 80                                id: tool_use.id.clone(),
 81                                name: tool_use.name.clone().into(),
 82                                raw_input: tool_use.input.to_string(),
 83                                input: tool_use.input.clone(),
 84                                is_input_complete: true,
 85                            })
 86                            .collect::<Vec<_>>();
 87
 88                        tool_names_by_id.extend(
 89                            tool_uses
 90                                .iter()
 91                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 92                        );
 93
 94                        this.tool_uses_by_assistant_message
 95                            .insert(message.id, tool_uses);
 96
 97                        for tool_result in &message.tool_results {
 98                            let tool_use_id = tool_result.tool_use_id.clone();
 99                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
101                                continue;
102                            };
103
104                            this.tool_results.insert(
105                                tool_use_id.clone(),
106                                LanguageModelToolResult {
107                                    tool_use_id: tool_use_id.clone(),
108                                    tool_name: tool_use.clone(),
109                                    is_error: tool_result.is_error,
110                                    content: tool_result.content.clone(),
111                                    output: tool_result.output.clone(),
112                                },
113                            );
114
115                            if let Some(window) = &mut window {
116                                if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
117                                    if let Some(output) = tool_result.output.clone() {
118                                        if let Some(card) = tool.deserialize_card(
119                                            output,
120                                            project.clone(),
121                                            window,
122                                            cx,
123                                        ) {
124                                            this.tool_result_cards.insert(tool_use_id, card);
125                                        }
126                                    }
127                                }
128                            }
129                        }
130                    }
131                }
132                Role::System | Role::User => {}
133            }
134        }
135
136        this
137    }
138
139    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
140        let mut cancelled_tool_uses = Vec::new();
141        self.pending_tool_uses_by_id
142            .retain(|tool_use_id, tool_use| {
143                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
144                    return true;
145                }
146
147                let content = "Tool canceled by user".into();
148                self.tool_results.insert(
149                    tool_use_id.clone(),
150                    LanguageModelToolResult {
151                        tool_use_id: tool_use_id.clone(),
152                        tool_name: tool_use.name.clone(),
153                        content,
154                        output: None,
155                        is_error: true,
156                    },
157                );
158                cancelled_tool_uses.push(tool_use.clone());
159                false
160            });
161        cancelled_tool_uses
162    }
163
164    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
165        self.pending_tool_uses_by_id.values().collect()
166    }
167
168    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
169        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
170            return Vec::new();
171        };
172
173        let mut tool_uses = Vec::new();
174
175        for tool_use in tool_uses_for_message.iter() {
176            let tool_result = self.tool_results.get(&tool_use.id);
177
178            let status = (|| {
179                if let Some(tool_result) = tool_result {
180                    let content = tool_result
181                        .content
182                        .to_str()
183                        .map(|str| str.to_owned().into())
184                        .unwrap_or_default();
185
186                    return if tool_result.is_error {
187                        ToolUseStatus::Error(content)
188                    } else {
189                        ToolUseStatus::Finished(content)
190                    };
191                }
192
193                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
194                    match pending_tool_use.status {
195                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
196                        PendingToolUseStatus::NeedsConfirmation { .. } => {
197                            ToolUseStatus::NeedsConfirmation
198                        }
199                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
200                        PendingToolUseStatus::Error(ref err) => {
201                            ToolUseStatus::Error(err.clone().into())
202                        }
203                        PendingToolUseStatus::InputStillStreaming => {
204                            ToolUseStatus::InputStillStreaming
205                        }
206                    }
207                } else {
208                    ToolUseStatus::Pending
209                }
210            })();
211
212            let (icon, needs_confirmation) =
213                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
214                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
215                } else {
216                    (IconName::Cog, false)
217                };
218
219            tool_uses.push(ToolUse {
220                id: tool_use.id.clone(),
221                name: tool_use.name.clone().into(),
222                ui_text: self.tool_ui_label(
223                    &tool_use.name,
224                    &tool_use.input,
225                    tool_use.is_input_complete,
226                    cx,
227                ),
228                input: tool_use.input.clone(),
229                status,
230                icon,
231                needs_confirmation,
232            })
233        }
234
235        tool_uses
236    }
237
238    pub fn tool_ui_label(
239        &self,
240        tool_name: &str,
241        input: &serde_json::Value,
242        is_input_complete: bool,
243        cx: &App,
244    ) -> SharedString {
245        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
246            if is_input_complete {
247                tool.ui_text(input).into()
248            } else {
249                tool.still_streaming_ui_text(input).into()
250            }
251        } else {
252            format!("Unknown tool {tool_name:?}").into()
253        }
254    }
255
256    pub fn tool_results_for_message(
257        &self,
258        assistant_message_id: MessageId,
259    ) -> Vec<&LanguageModelToolResult> {
260        let Some(tool_uses) = self
261            .tool_uses_by_assistant_message
262            .get(&assistant_message_id)
263        else {
264            return Vec::new();
265        };
266
267        tool_uses
268            .iter()
269            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
270            .collect()
271    }
272
273    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
274        self.tool_uses_by_assistant_message
275            .get(&assistant_message_id)
276            .map_or(false, |results| !results.is_empty())
277    }
278
279    pub fn tool_result(
280        &self,
281        tool_use_id: &LanguageModelToolUseId,
282    ) -> Option<&LanguageModelToolResult> {
283        self.tool_results.get(tool_use_id)
284    }
285
286    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
287        self.tool_result_cards.get(tool_use_id)
288    }
289
290    pub fn insert_tool_result_card(
291        &mut self,
292        tool_use_id: LanguageModelToolUseId,
293        card: AnyToolCard,
294    ) {
295        self.tool_result_cards.insert(tool_use_id, card);
296    }
297
298    pub fn request_tool_use(
299        &mut self,
300        assistant_message_id: MessageId,
301        tool_use: LanguageModelToolUse,
302        metadata: ToolUseMetadata,
303        cx: &App,
304    ) -> Arc<str> {
305        let tool_uses = self
306            .tool_uses_by_assistant_message
307            .entry(assistant_message_id)
308            .or_default();
309
310        let mut existing_tool_use_found = false;
311
312        for existing_tool_use in tool_uses.iter_mut() {
313            if existing_tool_use.id == tool_use.id {
314                *existing_tool_use = tool_use.clone();
315                existing_tool_use_found = true;
316            }
317        }
318
319        if !existing_tool_use_found {
320            tool_uses.push(tool_use.clone());
321        }
322
323        let status = if tool_use.is_input_complete {
324            self.tool_use_metadata_by_id
325                .insert(tool_use.id.clone(), metadata);
326
327            PendingToolUseStatus::Idle
328        } else {
329            PendingToolUseStatus::InputStillStreaming
330        };
331
332        let ui_text: Arc<str> = self
333            .tool_ui_label(
334                &tool_use.name,
335                &tool_use.input,
336                tool_use.is_input_complete,
337                cx,
338            )
339            .into();
340
341        let may_perform_edits = self
342            .tools
343            .read(cx)
344            .tool(&tool_use.name, cx)
345            .is_some_and(|tool| tool.may_perform_edits());
346
347        self.pending_tool_uses_by_id.insert(
348            tool_use.id.clone(),
349            PendingToolUse {
350                assistant_message_id,
351                id: tool_use.id,
352                name: tool_use.name.clone(),
353                ui_text: ui_text.clone(),
354                input: tool_use.input,
355                may_perform_edits,
356                status,
357            },
358        );
359
360        ui_text
361    }
362
363    pub fn run_pending_tool(
364        &mut self,
365        tool_use_id: LanguageModelToolUseId,
366        ui_text: SharedString,
367        task: Task<()>,
368    ) {
369        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
370            tool_use.ui_text = ui_text.into();
371            tool_use.status = PendingToolUseStatus::Running {
372                _task: task.shared(),
373            };
374        }
375    }
376
377    pub fn confirm_tool_use(
378        &mut self,
379        tool_use_id: LanguageModelToolUseId,
380        ui_text: impl Into<Arc<str>>,
381        input: serde_json::Value,
382        request: Arc<LanguageModelRequest>,
383        tool: Arc<dyn Tool>,
384    ) {
385        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
386            let ui_text = ui_text.into();
387            tool_use.ui_text = ui_text.clone();
388            let confirmation = Confirmation {
389                tool_use_id,
390                input,
391                request,
392                tool,
393                ui_text,
394            };
395            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
396        }
397    }
398
399    pub fn insert_tool_output(
400        &mut self,
401        tool_use_id: LanguageModelToolUseId,
402        tool_name: Arc<str>,
403        output: Result<ToolResultOutput>,
404        configured_model: Option<&ConfiguredModel>,
405        completion_mode: CompletionMode,
406    ) -> Option<PendingToolUse> {
407        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
408
409        telemetry::event!(
410            "Agent Tool Finished",
411            model = metadata
412                .as_ref()
413                .map(|metadata| metadata.model.telemetry_id()),
414            model_provider = metadata
415                .as_ref()
416                .map(|metadata| metadata.model.provider_id().to_string()),
417            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
418            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
419            tool_name,
420            success = output.is_ok()
421        );
422
423        match output {
424            Ok(output) => {
425                let tool_result = output.content;
426                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
427
428                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
429
430                // Protect from overly large output
431                let tool_output_limit = configured_model
432                    .map(|model| {
433                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
434                            * BYTES_PER_TOKEN_ESTIMATE
435                    })
436                    .unwrap_or(usize::MAX);
437
438                let content = match tool_result {
439                    ToolResultContent::Text(text) => {
440                        let text = if text.len() < tool_output_limit {
441                            text
442                        } else {
443                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
444                            format!(
445                                "Tool result too long. The first {} bytes:\n\n{}",
446                                truncated.len(),
447                                truncated
448                            )
449                        };
450                        LanguageModelToolResultContent::Text(text.into())
451                    }
452                    ToolResultContent::Image(language_model_image) => {
453                        if language_model_image.estimate_tokens() < tool_output_limit {
454                            LanguageModelToolResultContent::Image(language_model_image)
455                        } else {
456                            self.tool_results.insert(
457                                tool_use_id.clone(),
458                                LanguageModelToolResult {
459                                    tool_use_id: tool_use_id.clone(),
460                                    tool_name,
461                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
462                                    is_error: true,
463                                    output: None,
464                                },
465                            );
466
467                            return old_use;
468                        }
469                    }
470                };
471
472                self.tool_results.insert(
473                    tool_use_id.clone(),
474                    LanguageModelToolResult {
475                        tool_use_id: tool_use_id.clone(),
476                        tool_name,
477                        content,
478                        is_error: false,
479                        output: output.output,
480                    },
481                );
482
483                old_use
484            }
485            Err(err) => {
486                self.tool_results.insert(
487                    tool_use_id.clone(),
488                    LanguageModelToolResult {
489                        tool_use_id: tool_use_id.clone(),
490                        tool_name,
491                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
492                        is_error: true,
493                        output: None,
494                    },
495                );
496
497                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
498                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
499                }
500
501                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
502            }
503        }
504    }
505
506    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
507        self.tool_uses_by_assistant_message
508            .contains_key(&assistant_message_id)
509    }
510
511    pub fn tool_results(
512        &self,
513        assistant_message_id: MessageId,
514    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
515        self.tool_uses_by_assistant_message
516            .get(&assistant_message_id)
517            .into_iter()
518            .flatten()
519            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
520    }
521}
522
523#[derive(Debug, Clone)]
524pub struct PendingToolUse {
525    pub id: LanguageModelToolUseId,
526    /// The ID of the Assistant message in which the tool use was requested.
527    #[allow(unused)]
528    pub assistant_message_id: MessageId,
529    pub name: Arc<str>,
530    pub ui_text: Arc<str>,
531    pub input: serde_json::Value,
532    pub status: PendingToolUseStatus,
533    pub may_perform_edits: bool,
534}
535
536#[derive(Debug, Clone)]
537pub struct Confirmation {
538    pub tool_use_id: LanguageModelToolUseId,
539    pub input: serde_json::Value,
540    pub ui_text: Arc<str>,
541    pub request: Arc<LanguageModelRequest>,
542    pub tool: Arc<dyn Tool>,
543}
544
545#[derive(Debug, Clone)]
546pub enum PendingToolUseStatus {
547    InputStillStreaming,
548    Idle,
549    NeedsConfirmation(Arc<Confirmation>),
550    Running { _task: Shared<Task<()>> },
551    Error(#[allow(unused)] Arc<str>),
552}
553
554impl PendingToolUseStatus {
555    pub fn is_idle(&self) -> bool {
556        matches!(self, PendingToolUseStatus::Idle)
557    }
558
559    pub fn is_error(&self) -> bool {
560        matches!(self, PendingToolUseStatus::Error(_))
561    }
562
563    pub fn needs_confirmation(&self) -> bool {
564        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
565    }
566}
567
568#[derive(Clone)]
569pub struct ToolUseMetadata {
570    pub model: Arc<dyn LanguageModel>,
571    pub thread_id: ThreadId,
572    pub prompt_id: PromptId,
573}