tool_use.rs

  1use crate::{
  2    thread::{MessageId, PromptId, ThreadId},
  3    thread_store::SerializedMessage,
  4};
  5use agent_settings::CompletionMode;
  6use anyhow::Result;
  7use assistant_tool::{
  8    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  9};
 10use collections::HashMap;
 11use futures::{FutureExt as _, future::Shared};
 12use gpui::{App, Entity, SharedString, Task, Window};
 13use icons::IconName;
 14use language_model::{
 15    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
 16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
 17    LanguageModelToolUseId, Role,
 18};
 19use project::Project;
 20use std::sync::Arc;
 21use util::truncate_lines_to_byte_limit;
 22
/// A display-oriented snapshot of one tool use, as assembled by
/// `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier copied from the underlying `LanguageModelToolUse`.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label produced by `ToolUseState::tool_ui_label` for this tool and input.
    pub ui_text: SharedString,
    /// Status derived from the recorded result, or from the pending entry when no result exists.
    pub status: ToolUseStatus,
    /// JSON input supplied for the tool.
    pub input: serde_json::Value,
    /// The tool's icon; `IconName::Cog` when the tool is not found in the working set.
    pub icon: icons::IconName,
    /// What the tool's `needs_confirmation` reports for this input; `false` when the tool is
    /// not found in the working set.
    pub needs_confirmation: bool,
}
 33
/// Bookkeeping for the tool uses requested by assistant messages: what was requested, what is
/// still pending, and which results and cards have been recorded.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, keyed by the message's ID.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Results recorded per tool use (successful output, error, or cancellation).
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have not yet been removed by a successful result or
    /// by cancellation.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards attached to tool uses, either deserialized from saved output or inserted directly.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete; removed again by
    /// `insert_tool_output`, which reports it to telemetry.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 42
 43impl ToolUseState {
 44    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 45        Self {
 46            tools,
 47            tool_uses_by_assistant_message: HashMap::default(),
 48            tool_results: HashMap::default(),
 49            pending_tool_uses_by_id: HashMap::default(),
 50            tool_result_cards: HashMap::default(),
 51            tool_use_metadata_by_id: HashMap::default(),
 52        }
 53    }
 54
 55    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 56    ///
 57    /// Accepts a function to filter the tools that should be used to populate the state.
 58    ///
 59    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 60    /// tool cards won't be deserialized
 61    pub fn from_serialized_messages(
 62        tools: Entity<ToolWorkingSet>,
 63        messages: &[SerializedMessage],
 64        project: Entity<Project>,
 65        window: Option<&mut Window>, // None in headless mode
 66        cx: &mut App,
 67    ) -> Self {
 68        let mut this = Self::new(tools);
 69        let mut tool_names_by_id = HashMap::default();
 70        let mut window = window;
 71
 72        for message in messages {
 73            match message.role {
 74                Role::Assistant => {
 75                    if !message.tool_uses.is_empty() {
 76                        let tool_uses = message
 77                            .tool_uses
 78                            .iter()
 79                            .map(|tool_use| LanguageModelToolUse {
 80                                id: tool_use.id.clone(),
 81                                name: tool_use.name.clone().into(),
 82                                raw_input: tool_use.input.to_string(),
 83                                input: tool_use.input.clone(),
 84                                is_input_complete: true,
 85                            })
 86                            .collect::<Vec<_>>();
 87
 88                        tool_names_by_id.extend(
 89                            tool_uses
 90                                .iter()
 91                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 92                        );
 93
 94                        this.tool_uses_by_assistant_message
 95                            .insert(message.id, tool_uses);
 96
 97                        for tool_result in &message.tool_results {
 98                            let tool_use_id = tool_result.tool_use_id.clone();
 99                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
101                                continue;
102                            };
103
104                            this.tool_results.insert(
105                                tool_use_id.clone(),
106                                LanguageModelToolResult {
107                                    tool_use_id: tool_use_id.clone(),
108                                    tool_name: tool_use.clone(),
109                                    is_error: tool_result.is_error,
110                                    content: tool_result.content.clone(),
111                                    output: tool_result.output.clone(),
112                                },
113                            );
114
115                            if let Some(window) = &mut window {
116                                if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
117                                    if let Some(output) = tool_result.output.clone() {
118                                        if let Some(card) = tool.deserialize_card(
119                                            output,
120                                            project.clone(),
121                                            window,
122                                            cx,
123                                        ) {
124                                            this.tool_result_cards.insert(tool_use_id, card);
125                                        }
126                                    }
127                                }
128                            }
129                        }
130                    }
131                }
132                Role::System | Role::User => {}
133            }
134        }
135
136        this
137    }
138
139    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
140        let mut canceled_tool_uses = Vec::new();
141        self.pending_tool_uses_by_id
142            .retain(|tool_use_id, tool_use| {
143                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
144                    return true;
145                }
146
147                let content = "Tool canceled by user".into();
148                self.tool_results.insert(
149                    tool_use_id.clone(),
150                    LanguageModelToolResult {
151                        tool_use_id: tool_use_id.clone(),
152                        tool_name: tool_use.name.clone(),
153                        content,
154                        output: None,
155                        is_error: true,
156                    },
157                );
158                canceled_tool_uses.push(tool_use.clone());
159                false
160            });
161        canceled_tool_uses
162    }
163
164    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
165        self.pending_tool_uses_by_id.values().collect()
166    }
167
168    pub fn tool_uses_for_message(
169        &self,
170        id: MessageId,
171        project: &Entity<Project>,
172        cx: &App,
173    ) -> Vec<ToolUse> {
174        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
175            return Vec::new();
176        };
177
178        let mut tool_uses = Vec::new();
179
180        for tool_use in tool_uses_for_message.iter() {
181            let tool_result = self.tool_results.get(&tool_use.id);
182
183            let status = (|| {
184                if let Some(tool_result) = tool_result {
185                    let content = tool_result
186                        .content
187                        .to_str()
188                        .map(|str| str.to_owned().into())
189                        .unwrap_or_default();
190
191                    return if tool_result.is_error {
192                        ToolUseStatus::Error(content)
193                    } else {
194                        ToolUseStatus::Finished(content)
195                    };
196                }
197
198                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
199                    match pending_tool_use.status {
200                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
201                        PendingToolUseStatus::NeedsConfirmation { .. } => {
202                            ToolUseStatus::NeedsConfirmation
203                        }
204                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
205                        PendingToolUseStatus::Error(ref err) => {
206                            ToolUseStatus::Error(err.clone().into())
207                        }
208                        PendingToolUseStatus::InputStillStreaming => {
209                            ToolUseStatus::InputStillStreaming
210                        }
211                    }
212                } else {
213                    ToolUseStatus::Pending
214                }
215            })();
216
217            let (icon, needs_confirmation) =
218                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
219                    (
220                        tool.icon(),
221                        tool.needs_confirmation(&tool_use.input, project, cx),
222                    )
223                } else {
224                    (IconName::Cog, false)
225                };
226
227            tool_uses.push(ToolUse {
228                id: tool_use.id.clone(),
229                name: tool_use.name.clone().into(),
230                ui_text: self.tool_ui_label(
231                    &tool_use.name,
232                    &tool_use.input,
233                    tool_use.is_input_complete,
234                    cx,
235                ),
236                input: tool_use.input.clone(),
237                status,
238                icon,
239                needs_confirmation,
240            })
241        }
242
243        tool_uses
244    }
245
246    pub fn tool_ui_label(
247        &self,
248        tool_name: &str,
249        input: &serde_json::Value,
250        is_input_complete: bool,
251        cx: &App,
252    ) -> SharedString {
253        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
254            if is_input_complete {
255                tool.ui_text(input).into()
256            } else {
257                tool.still_streaming_ui_text(input).into()
258            }
259        } else {
260            format!("Unknown tool {tool_name:?}").into()
261        }
262    }
263
264    pub fn tool_results_for_message(
265        &self,
266        assistant_message_id: MessageId,
267    ) -> Vec<&LanguageModelToolResult> {
268        let Some(tool_uses) = self
269            .tool_uses_by_assistant_message
270            .get(&assistant_message_id)
271        else {
272            return Vec::new();
273        };
274
275        tool_uses
276            .iter()
277            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
278            .collect()
279    }
280
281    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
282        self.tool_uses_by_assistant_message
283            .get(&assistant_message_id)
284            .map_or(false, |results| !results.is_empty())
285    }
286
287    pub fn tool_result(
288        &self,
289        tool_use_id: &LanguageModelToolUseId,
290    ) -> Option<&LanguageModelToolResult> {
291        self.tool_results.get(tool_use_id)
292    }
293
294    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
295        self.tool_result_cards.get(tool_use_id)
296    }
297
298    pub fn insert_tool_result_card(
299        &mut self,
300        tool_use_id: LanguageModelToolUseId,
301        card: AnyToolCard,
302    ) {
303        self.tool_result_cards.insert(tool_use_id, card);
304    }
305
306    pub fn request_tool_use(
307        &mut self,
308        assistant_message_id: MessageId,
309        tool_use: LanguageModelToolUse,
310        metadata: ToolUseMetadata,
311        cx: &App,
312    ) -> Arc<str> {
313        let tool_uses = self
314            .tool_uses_by_assistant_message
315            .entry(assistant_message_id)
316            .or_default();
317
318        let mut existing_tool_use_found = false;
319
320        for existing_tool_use in tool_uses.iter_mut() {
321            if existing_tool_use.id == tool_use.id {
322                *existing_tool_use = tool_use.clone();
323                existing_tool_use_found = true;
324            }
325        }
326
327        if !existing_tool_use_found {
328            tool_uses.push(tool_use.clone());
329        }
330
331        let status = if tool_use.is_input_complete {
332            self.tool_use_metadata_by_id
333                .insert(tool_use.id.clone(), metadata);
334
335            PendingToolUseStatus::Idle
336        } else {
337            PendingToolUseStatus::InputStillStreaming
338        };
339
340        let ui_text: Arc<str> = self
341            .tool_ui_label(
342                &tool_use.name,
343                &tool_use.input,
344                tool_use.is_input_complete,
345                cx,
346            )
347            .into();
348
349        let may_perform_edits = self
350            .tools
351            .read(cx)
352            .tool(&tool_use.name, cx)
353            .is_some_and(|tool| tool.may_perform_edits());
354
355        self.pending_tool_uses_by_id.insert(
356            tool_use.id.clone(),
357            PendingToolUse {
358                assistant_message_id,
359                id: tool_use.id,
360                name: tool_use.name.clone(),
361                ui_text: ui_text.clone(),
362                input: tool_use.input,
363                may_perform_edits,
364                status,
365            },
366        );
367
368        ui_text
369    }
370
371    pub fn run_pending_tool(
372        &mut self,
373        tool_use_id: LanguageModelToolUseId,
374        ui_text: SharedString,
375        task: Task<()>,
376    ) {
377        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
378            tool_use.ui_text = ui_text.into();
379            tool_use.status = PendingToolUseStatus::Running {
380                _task: task.shared(),
381            };
382        }
383    }
384
385    pub fn confirm_tool_use(
386        &mut self,
387        tool_use_id: LanguageModelToolUseId,
388        ui_text: impl Into<Arc<str>>,
389        input: serde_json::Value,
390        request: Arc<LanguageModelRequest>,
391        tool: Arc<dyn Tool>,
392    ) {
393        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
394            let ui_text = ui_text.into();
395            tool_use.ui_text = ui_text.clone();
396            let confirmation = Confirmation {
397                tool_use_id,
398                input,
399                request,
400                tool,
401                ui_text,
402            };
403            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
404        }
405    }
406
407    pub fn insert_tool_output(
408        &mut self,
409        tool_use_id: LanguageModelToolUseId,
410        tool_name: Arc<str>,
411        output: Result<ToolResultOutput>,
412        configured_model: Option<&ConfiguredModel>,
413        completion_mode: CompletionMode,
414    ) -> Option<PendingToolUse> {
415        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
416
417        telemetry::event!(
418            "Agent Tool Finished",
419            model = metadata
420                .as_ref()
421                .map(|metadata| metadata.model.telemetry_id()),
422            model_provider = metadata
423                .as_ref()
424                .map(|metadata| metadata.model.provider_id().to_string()),
425            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
426            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
427            tool_name,
428            success = output.is_ok()
429        );
430
431        match output {
432            Ok(output) => {
433                let tool_result = output.content;
434                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
435
436                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
437
438                // Protect from overly large output
439                let tool_output_limit = configured_model
440                    .map(|model| {
441                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
442                            * BYTES_PER_TOKEN_ESTIMATE
443                    })
444                    .unwrap_or(usize::MAX);
445
446                let content = match tool_result {
447                    ToolResultContent::Text(text) => {
448                        let text = if text.len() < tool_output_limit {
449                            text
450                        } else {
451                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
452                            format!(
453                                "Tool result too long. The first {} bytes:\n\n{}",
454                                truncated.len(),
455                                truncated
456                            )
457                        };
458                        LanguageModelToolResultContent::Text(text.into())
459                    }
460                    ToolResultContent::Image(language_model_image) => {
461                        if language_model_image.estimate_tokens() < tool_output_limit {
462                            LanguageModelToolResultContent::Image(language_model_image)
463                        } else {
464                            self.tool_results.insert(
465                                tool_use_id.clone(),
466                                LanguageModelToolResult {
467                                    tool_use_id: tool_use_id.clone(),
468                                    tool_name,
469                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
470                                    is_error: true,
471                                    output: None,
472                                },
473                            );
474
475                            return old_use;
476                        }
477                    }
478                };
479
480                self.tool_results.insert(
481                    tool_use_id.clone(),
482                    LanguageModelToolResult {
483                        tool_use_id: tool_use_id.clone(),
484                        tool_name,
485                        content,
486                        is_error: false,
487                        output: output.output,
488                    },
489                );
490
491                old_use
492            }
493            Err(err) => {
494                self.tool_results.insert(
495                    tool_use_id.clone(),
496                    LanguageModelToolResult {
497                        tool_use_id: tool_use_id.clone(),
498                        tool_name,
499                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
500                        is_error: true,
501                        output: None,
502                    },
503                );
504
505                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
506                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
507                }
508
509                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
510            }
511        }
512    }
513
514    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
515        self.tool_uses_by_assistant_message
516            .contains_key(&assistant_message_id)
517    }
518
519    pub fn tool_results(
520        &self,
521        assistant_message_id: MessageId,
522    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
523        self.tool_uses_by_assistant_message
524            .get(&assistant_message_id)
525            .into_iter()
526            .flatten()
527            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
528    }
529}
530
/// A tool use that has been requested and is tracked in `ToolUseState`'s pending map.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use; also the key under which it is stored.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Current UI label; set on request and replaced when the tool is confirmed or run.
    pub ui_text: Arc<str>,
    /// JSON input supplied with the request.
    pub input: serde_json::Value,
    /// Where the tool use currently stands (see [`PendingToolUseStatus`]).
    pub status: PendingToolUseStatus,
    /// What the tool's `may_perform_edits` reported at request time; `false` when the tool
    /// was not found in the working set.
    pub may_perform_edits: bool,
}
543
/// The arguments captured by `ToolUseState::confirm_tool_use` for a tool use that is waiting
/// for confirmation.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// UI label; the same value is also written to the pending tool use.
    pub ui_text: Arc<str>,
    /// The language model request handed to `confirm_tool_use`.
    // NOTE(review): presumably the request that triggered this tool use — confirm against callers.
    pub request: Arc<LanguageModelRequest>,
    /// The tool handed to `confirm_tool_use`.
    pub tool: Arc<dyn Tool>,
}
552
/// The state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Set on request while the tool use's input is not yet complete.
    InputStillStreaming,
    /// Set on request once the tool use's input is complete.
    Idle,
    /// Set by `ToolUseState::confirm_tool_use`; carries the captured [`Confirmation`].
    NeedsConfirmation(Arc<Confirmation>),
    /// Set by `ToolUseState::run_pending_tool`; holds the shared task passed to it.
    Running { _task: Shared<Task<()>> },
    /// Set by `ToolUseState::insert_tool_output` when the tool returned an error; carries the
    /// error's string form.
    Error(#[allow(unused)] Arc<str>),
}
561
562impl PendingToolUseStatus {
563    pub fn is_idle(&self) -> bool {
564        matches!(self, PendingToolUseStatus::Idle)
565    }
566
567    pub fn is_error(&self) -> bool {
568        matches!(self, PendingToolUseStatus::Error(_))
569    }
570
571    pub fn needs_confirmation(&self) -> bool {
572        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
573    }
574}
575
/// Context stored alongside a tool use and reported in the "Agent Tool Finished" telemetry
/// event when its output is inserted.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry ID and provider ID are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread ID reported with the event.
    pub thread_id: ThreadId,
    /// Prompt ID reported with the event.
    pub prompt_id: PromptId,
}