tool_use.rs

  1use crate::{
  2    thread::{MessageId, PromptId, ThreadId},
  3    thread_store::SerializedMessage,
  4};
  5use agent_settings::CompletionMode;
  6use anyhow::Result;
  7use assistant_tool::{
  8    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  9};
 10use collections::HashMap;
 11use futures::{FutureExt as _, future::Shared};
 12use gpui::{App, Entity, SharedString, Task, Window};
 13use icons::IconName;
 14use language_model::{
 15    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
 16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
 17    LanguageModelToolUseId, Role,
 18};
 19use project::Project;
 20use std::sync::Arc;
 21use util::truncate_lines_to_byte_limit;
 22
/// A display-oriented view of a single tool use, as produced by
/// `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use this entry describes.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label for the tool use, as computed by `ToolUseState::tool_ui_label`.
    pub ui_text: SharedString,
    /// Current status, derived from the stored result or the pending entry.
    pub status: ToolUseStatus,
    /// Input the tool was invoked with.
    pub input: serde_json::Value,
    /// Icon reported by the tool, or `IconName::Cog` when the tool is not found.
    pub icon: icons::IconName,
    /// Whether the tool reports that this input needs confirmation; `false`
    /// when the tool is not found.
    pub needs_confirmation: bool,
}
 33
/// Bookkeeping for tool uses: which assistant message requested them, which
/// are still pending, and which results and cards have been recorded.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses grouped by the ID of the assistant message that requested them.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Recorded tool results, keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have been requested but have no successful result yet.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards associated with tool uses, keyed by tool use ID.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete and removed again
    /// when its output is inserted (where it is used for the telemetry event).
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 42
 43impl ToolUseState {
 44    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 45        Self {
 46            tools,
 47            tool_uses_by_assistant_message: HashMap::default(),
 48            tool_results: HashMap::default(),
 49            pending_tool_uses_by_id: HashMap::default(),
 50            tool_result_cards: HashMap::default(),
 51            tool_use_metadata_by_id: HashMap::default(),
 52        }
 53    }
 54
 55    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 56    ///
 57    /// Accepts a function to filter the tools that should be used to populate the state.
 58    ///
 59    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 60    /// tool cards won't be deserialized
 61    pub fn from_serialized_messages(
 62        tools: Entity<ToolWorkingSet>,
 63        messages: &[SerializedMessage],
 64        project: Entity<Project>,
 65        window: Option<&mut Window>, // None in headless mode
 66        cx: &mut App,
 67    ) -> Self {
 68        let mut this = Self::new(tools);
 69        let mut tool_names_by_id = HashMap::default();
 70        let mut window = window;
 71
 72        for message in messages {
 73            match message.role {
 74                Role::Assistant => {
 75                    if !message.tool_uses.is_empty() {
 76                        let tool_uses = message
 77                            .tool_uses
 78                            .iter()
 79                            .map(|tool_use| LanguageModelToolUse {
 80                                id: tool_use.id.clone(),
 81                                name: tool_use.name.clone().into(),
 82                                raw_input: tool_use.input.to_string(),
 83                                input: tool_use.input.clone(),
 84                                is_input_complete: true,
 85                            })
 86                            .collect::<Vec<_>>();
 87
 88                        tool_names_by_id.extend(
 89                            tool_uses
 90                                .iter()
 91                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 92                        );
 93
 94                        this.tool_uses_by_assistant_message
 95                            .insert(message.id, tool_uses);
 96
 97                        for tool_result in &message.tool_results {
 98                            let tool_use_id = tool_result.tool_use_id.clone();
 99                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
101                                continue;
102                            };
103
104                            this.tool_results.insert(
105                                tool_use_id.clone(),
106                                LanguageModelToolResult {
107                                    tool_use_id: tool_use_id.clone(),
108                                    tool_name: tool_use.clone(),
109                                    is_error: tool_result.is_error,
110                                    content: tool_result.content.clone(),
111                                    output: tool_result.output.clone(),
112                                },
113                            );
114
115                            if let Some(window) = &mut window
116                                && let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
117                                && let Some(output) = tool_result.output.clone()
118                                && let Some(card) =
119                                    tool.deserialize_card(output, project.clone(), window, cx)
120                            {
121                                this.tool_result_cards.insert(tool_use_id, card);
122                            }
123                        }
124                    }
125                }
126                Role::System | Role::User => {}
127            }
128        }
129
130        this
131    }
132
133    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
134        let mut canceled_tool_uses = Vec::new();
135        self.pending_tool_uses_by_id
136            .retain(|tool_use_id, tool_use| {
137                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
138                    return true;
139                }
140
141                let content = "Tool canceled by user".into();
142                self.tool_results.insert(
143                    tool_use_id.clone(),
144                    LanguageModelToolResult {
145                        tool_use_id: tool_use_id.clone(),
146                        tool_name: tool_use.name.clone(),
147                        content,
148                        output: None,
149                        is_error: true,
150                    },
151                );
152                canceled_tool_uses.push(tool_use.clone());
153                false
154            });
155        canceled_tool_uses
156    }
157
158    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
159        self.pending_tool_uses_by_id.values().collect()
160    }
161
162    pub fn tool_uses_for_message(
163        &self,
164        id: MessageId,
165        project: &Entity<Project>,
166        cx: &App,
167    ) -> Vec<ToolUse> {
168        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169            return Vec::new();
170        };
171
172        let mut tool_uses = Vec::new();
173
174        for tool_use in tool_uses_for_message.iter() {
175            let tool_result = self.tool_results.get(&tool_use.id);
176
177            let status = (|| {
178                if let Some(tool_result) = tool_result {
179                    let content = tool_result
180                        .content
181                        .to_str()
182                        .map(|str| str.to_owned().into())
183                        .unwrap_or_default();
184
185                    return if tool_result.is_error {
186                        ToolUseStatus::Error(content)
187                    } else {
188                        ToolUseStatus::Finished(content)
189                    };
190                }
191
192                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
193                    match pending_tool_use.status {
194                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
195                        PendingToolUseStatus::NeedsConfirmation { .. } => {
196                            ToolUseStatus::NeedsConfirmation
197                        }
198                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
199                        PendingToolUseStatus::Error(ref err) => {
200                            ToolUseStatus::Error(err.clone().into())
201                        }
202                        PendingToolUseStatus::InputStillStreaming => {
203                            ToolUseStatus::InputStillStreaming
204                        }
205                    }
206                } else {
207                    ToolUseStatus::Pending
208                }
209            })();
210
211            let (icon, needs_confirmation) =
212                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
213                    (
214                        tool.icon(),
215                        tool.needs_confirmation(&tool_use.input, project, cx),
216                    )
217                } else {
218                    (IconName::Cog, false)
219                };
220
221            tool_uses.push(ToolUse {
222                id: tool_use.id.clone(),
223                name: tool_use.name.clone().into(),
224                ui_text: self.tool_ui_label(
225                    &tool_use.name,
226                    &tool_use.input,
227                    tool_use.is_input_complete,
228                    cx,
229                ),
230                input: tool_use.input.clone(),
231                status,
232                icon,
233                needs_confirmation,
234            })
235        }
236
237        tool_uses
238    }
239
240    pub fn tool_ui_label(
241        &self,
242        tool_name: &str,
243        input: &serde_json::Value,
244        is_input_complete: bool,
245        cx: &App,
246    ) -> SharedString {
247        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
248            if is_input_complete {
249                tool.ui_text(input).into()
250            } else {
251                tool.still_streaming_ui_text(input).into()
252            }
253        } else {
254            format!("Unknown tool {tool_name:?}").into()
255        }
256    }
257
258    pub fn tool_results_for_message(
259        &self,
260        assistant_message_id: MessageId,
261    ) -> Vec<&LanguageModelToolResult> {
262        let Some(tool_uses) = self
263            .tool_uses_by_assistant_message
264            .get(&assistant_message_id)
265        else {
266            return Vec::new();
267        };
268
269        tool_uses
270            .iter()
271            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
272            .collect()
273    }
274
275    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
276        self.tool_uses_by_assistant_message
277            .get(&assistant_message_id)
278            .is_some_and(|results| !results.is_empty())
279    }
280
    /// Returns the recorded result for `tool_use_id`, or `None` if no result
    /// has been stored for it.
    pub fn tool_result(
        &self,
        tool_use_id: &LanguageModelToolUseId,
    ) -> Option<&LanguageModelToolResult> {
        self.tool_results.get(tool_use_id)
    }
287
    /// Returns the card stored for `tool_use_id`, or `None` if there is none.
    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
        self.tool_result_cards.get(tool_use_id)
    }
291
    /// Stores `card` for `tool_use_id`, replacing any card already stored
    /// under that ID.
    pub fn insert_tool_result_card(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        card: AnyToolCard,
    ) {
        self.tool_result_cards.insert(tool_use_id, card);
    }
299
300    pub fn request_tool_use(
301        &mut self,
302        assistant_message_id: MessageId,
303        tool_use: LanguageModelToolUse,
304        metadata: ToolUseMetadata,
305        cx: &App,
306    ) -> Arc<str> {
307        let tool_uses = self
308            .tool_uses_by_assistant_message
309            .entry(assistant_message_id)
310            .or_default();
311
312        let mut existing_tool_use_found = false;
313
314        for existing_tool_use in tool_uses.iter_mut() {
315            if existing_tool_use.id == tool_use.id {
316                *existing_tool_use = tool_use.clone();
317                existing_tool_use_found = true;
318            }
319        }
320
321        if !existing_tool_use_found {
322            tool_uses.push(tool_use.clone());
323        }
324
325        let status = if tool_use.is_input_complete {
326            self.tool_use_metadata_by_id
327                .insert(tool_use.id.clone(), metadata);
328
329            PendingToolUseStatus::Idle
330        } else {
331            PendingToolUseStatus::InputStillStreaming
332        };
333
334        let ui_text: Arc<str> = self
335            .tool_ui_label(
336                &tool_use.name,
337                &tool_use.input,
338                tool_use.is_input_complete,
339                cx,
340            )
341            .into();
342
343        let may_perform_edits = self
344            .tools
345            .read(cx)
346            .tool(&tool_use.name, cx)
347            .is_some_and(|tool| tool.may_perform_edits());
348
349        self.pending_tool_uses_by_id.insert(
350            tool_use.id.clone(),
351            PendingToolUse {
352                assistant_message_id,
353                id: tool_use.id,
354                name: tool_use.name.clone(),
355                ui_text: ui_text.clone(),
356                input: tool_use.input,
357                may_perform_edits,
358                status,
359            },
360        );
361
362        ui_text
363    }
364
365    pub fn run_pending_tool(
366        &mut self,
367        tool_use_id: LanguageModelToolUseId,
368        ui_text: SharedString,
369        task: Task<()>,
370    ) {
371        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
372            tool_use.ui_text = ui_text.into();
373            tool_use.status = PendingToolUseStatus::Running {
374                _task: task.shared(),
375            };
376        }
377    }
378
379    pub fn confirm_tool_use(
380        &mut self,
381        tool_use_id: LanguageModelToolUseId,
382        ui_text: impl Into<Arc<str>>,
383        input: serde_json::Value,
384        request: Arc<LanguageModelRequest>,
385        tool: Arc<dyn Tool>,
386    ) {
387        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
388            let ui_text = ui_text.into();
389            tool_use.ui_text = ui_text.clone();
390            let confirmation = Confirmation {
391                tool_use_id,
392                input,
393                request,
394                tool,
395                ui_text,
396            };
397            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
398        }
399    }
400
401    pub fn insert_tool_output(
402        &mut self,
403        tool_use_id: LanguageModelToolUseId,
404        tool_name: Arc<str>,
405        output: Result<ToolResultOutput>,
406        configured_model: Option<&ConfiguredModel>,
407        completion_mode: CompletionMode,
408    ) -> Option<PendingToolUse> {
409        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
410
411        telemetry::event!(
412            "Agent Tool Finished",
413            model = metadata
414                .as_ref()
415                .map(|metadata| metadata.model.telemetry_id()),
416            model_provider = metadata
417                .as_ref()
418                .map(|metadata| metadata.model.provider_id().to_string()),
419            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
420            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
421            tool_name,
422            success = output.is_ok()
423        );
424
425        match output {
426            Ok(output) => {
427                let tool_result = output.content;
428                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
429
430                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
431
432                // Protect from overly large output
433                let tool_output_limit = configured_model
434                    .map(|model| {
435                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
436                            * BYTES_PER_TOKEN_ESTIMATE
437                    })
438                    .unwrap_or(usize::MAX);
439
440                let content = match tool_result {
441                    ToolResultContent::Text(text) => {
442                        let text = if text.len() < tool_output_limit {
443                            text
444                        } else {
445                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
446                            format!(
447                                "Tool result too long. The first {} bytes:\n\n{}",
448                                truncated.len(),
449                                truncated
450                            )
451                        };
452                        LanguageModelToolResultContent::Text(text.into())
453                    }
454                    ToolResultContent::Image(language_model_image) => {
455                        if language_model_image.estimate_tokens() < tool_output_limit {
456                            LanguageModelToolResultContent::Image(language_model_image)
457                        } else {
458                            self.tool_results.insert(
459                                tool_use_id.clone(),
460                                LanguageModelToolResult {
461                                    tool_use_id: tool_use_id.clone(),
462                                    tool_name,
463                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
464                                    is_error: true,
465                                    output: None,
466                                },
467                            );
468
469                            return old_use;
470                        }
471                    }
472                };
473
474                self.tool_results.insert(
475                    tool_use_id.clone(),
476                    LanguageModelToolResult {
477                        tool_use_id: tool_use_id.clone(),
478                        tool_name,
479                        content,
480                        is_error: false,
481                        output: output.output,
482                    },
483                );
484
485                old_use
486            }
487            Err(err) => {
488                self.tool_results.insert(
489                    tool_use_id.clone(),
490                    LanguageModelToolResult {
491                        tool_use_id: tool_use_id.clone(),
492                        tool_name,
493                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
494                        is_error: true,
495                        output: None,
496                    },
497                );
498
499                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
500                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
501                }
502
503                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
504            }
505        }
506    }
507
    /// Returns `true` when an entry exists for the message in the tool-use
    /// map, even if that entry is empty.
    ///
    /// Note that this checks recorded tool uses, not the tool results map.
    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
        self.tool_uses_by_assistant_message
            .contains_key(&assistant_message_id)
    }
512
513    pub fn tool_results(
514        &self,
515        assistant_message_id: MessageId,
516    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
517        self.tool_uses_by_assistant_message
518            .get(&assistant_message_id)
519            .into_iter()
520            .flatten()
521            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
522    }
523}
524
/// A tool use that has been requested and is tracked in the pending map of
/// `ToolUseState`.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; updated when the tool is run or needs confirmation.
    pub ui_text: Arc<str>,
    /// Input the tool was requested with.
    pub input: serde_json::Value,
    /// Where the tool use currently is in its lifecycle.
    pub status: PendingToolUseStatus,
    /// Whether the tool reports that it may perform edits; `false` when the
    /// tool was not found at request time.
    pub may_perform_edits: bool,
}
537
/// Everything captured by `ToolUseState::confirm_tool_use` for a tool use
/// that is waiting for confirmation.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Input passed along with the confirmation request.
    pub input: serde_json::Value,
    /// Label for the tool use at the time confirmation was requested.
    pub ui_text: Arc<str>,
    /// The language model request associated with the tool use.
    pub request: Arc<LanguageModelRequest>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
546
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The tool use was requested while its input was not yet complete.
    InputStillStreaming,
    /// The input is complete and the tool has not been started.
    Idle,
    /// The tool use is waiting for confirmation; carries the captured details.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running; holds a shared handle to its task.
    Running { _task: Shared<Task<()>> },
    /// The tool failed; carries the error text.
    Error(#[allow(unused)] Arc<str>),
}
555
556impl PendingToolUseStatus {
557    pub fn is_idle(&self) -> bool {
558        matches!(self, PendingToolUseStatus::Idle)
559    }
560
561    pub fn is_error(&self) -> bool {
562        matches!(self, PendingToolUseStatus::Error(_))
563    }
564
565    pub fn needs_confirmation(&self) -> bool {
566        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
567    }
568}
569
/// Context stored for a tool use once its input is complete; read back when
/// the tool's output is inserted, where it populates the telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry ID and provider ID are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread ID reported with the telemetry event.
    pub thread_id: ThreadId,
    /// Prompt ID reported with the telemetry event.
    pub prompt_id: PromptId,
}