tool_use.rs

  1use crate::{
  2    thread::{MessageId, PromptId, ThreadId},
  3    thread_store::SerializedMessage,
  4};
  5use agent_settings::CompletionMode;
  6use anyhow::Result;
  7use assistant_tool::{
  8    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  9};
 10use collections::HashMap;
 11use futures::{FutureExt as _, future::Shared};
 12use gpui::{App, Entity, SharedString, Task, Window};
 13use icons::IconName;
 14use language_model::{
 15    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
 16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
 17    LanguageModelToolUseId, Role,
 18};
 19use project::Project;
 20use std::sync::Arc;
 21use util::truncate_lines_to_byte_limit;
 22
/// A tool use as reported by `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the requested tool.
    pub name: SharedString,
    /// Label computed by `ToolUseState::tool_ui_label` for this tool use.
    pub ui_text: SharedString,
    /// Status derived from the stored result or, failing that, the pending entry.
    pub status: ToolUseStatus,
    /// Input the tool was requested with.
    pub input: serde_json::Value,
    /// Icon reported by the tool; `IconName::Cog` when the tool is not in the working set.
    pub icon: icons::IconName,
    /// Whether the tool reports needing confirmation for this input; `false` for unknown tools.
    pub needs_confirmation: bool,
}
 33
/// Bookkeeping for the tool uses requested by assistant messages: the requests
/// themselves, their pending state, their results, and any result cards.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses recorded for each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Results keyed by tool use id (this includes error and cancellation results).
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and whose entry has not been removed yet.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards attached to tool uses, either deserialized or inserted explicitly.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete; consumed by
    /// `insert_tool_output` for its telemetry event.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 42
 43impl ToolUseState {
 44    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 45        Self {
 46            tools,
 47            tool_uses_by_assistant_message: HashMap::default(),
 48            tool_results: HashMap::default(),
 49            pending_tool_uses_by_id: HashMap::default(),
 50            tool_result_cards: HashMap::default(),
 51            tool_use_metadata_by_id: HashMap::default(),
 52        }
 53    }
 54
 55    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 56    ///
 57    /// Accepts a function to filter the tools that should be used to populate the state.
 58    ///
 59    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 60    /// tool cards won't be deserialized
 61    pub fn from_serialized_messages(
 62        tools: Entity<ToolWorkingSet>,
 63        messages: &[SerializedMessage],
 64        project: Entity<Project>,
 65        window: Option<&mut Window>, // None in headless mode
 66        cx: &mut App,
 67    ) -> Self {
 68        let mut this = Self::new(tools);
 69        let mut tool_names_by_id = HashMap::default();
 70        let mut window = window;
 71
 72        for message in messages {
 73            match message.role {
 74                Role::Assistant => {
 75                    if !message.tool_uses.is_empty() {
 76                        let tool_uses = message
 77                            .tool_uses
 78                            .iter()
 79                            .map(|tool_use| LanguageModelToolUse {
 80                                id: tool_use.id.clone(),
 81                                name: tool_use.name.clone().into(),
 82                                raw_input: tool_use.input.to_string(),
 83                                input: tool_use.input.clone(),
 84                                is_input_complete: true,
 85                            })
 86                            .collect::<Vec<_>>();
 87
 88                        tool_names_by_id.extend(
 89                            tool_uses
 90                                .iter()
 91                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 92                        );
 93
 94                        this.tool_uses_by_assistant_message
 95                            .insert(message.id, tool_uses);
 96
 97                        for tool_result in &message.tool_results {
 98                            let tool_use_id = tool_result.tool_use_id.clone();
 99                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
101                                continue;
102                            };
103
104                            this.tool_results.insert(
105                                tool_use_id.clone(),
106                                LanguageModelToolResult {
107                                    tool_use_id: tool_use_id.clone(),
108                                    tool_name: tool_use.clone(),
109                                    is_error: tool_result.is_error,
110                                    content: tool_result.content.clone(),
111                                    output: tool_result.output.clone(),
112                                },
113                            );
114
115                            if let Some(window) = &mut window
116                                && let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
117                                    && let Some(output) = tool_result.output.clone()
118                                        && let Some(card) = tool.deserialize_card(
119                                            output,
120                                            project.clone(),
121                                            window,
122                                            cx,
123                                        ) {
124                                            this.tool_result_cards.insert(tool_use_id, card);
125                                        }
126                        }
127                    }
128                }
129                Role::System | Role::User => {}
130            }
131        }
132
133        this
134    }
135
136    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
137        let mut canceled_tool_uses = Vec::new();
138        self.pending_tool_uses_by_id
139            .retain(|tool_use_id, tool_use| {
140                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
141                    return true;
142                }
143
144                let content = "Tool canceled by user".into();
145                self.tool_results.insert(
146                    tool_use_id.clone(),
147                    LanguageModelToolResult {
148                        tool_use_id: tool_use_id.clone(),
149                        tool_name: tool_use.name.clone(),
150                        content,
151                        output: None,
152                        is_error: true,
153                    },
154                );
155                canceled_tool_uses.push(tool_use.clone());
156                false
157            });
158        canceled_tool_uses
159    }
160
161    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
162        self.pending_tool_uses_by_id.values().collect()
163    }
164
165    pub fn tool_uses_for_message(
166        &self,
167        id: MessageId,
168        project: &Entity<Project>,
169        cx: &App,
170    ) -> Vec<ToolUse> {
171        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
172            return Vec::new();
173        };
174
175        let mut tool_uses = Vec::new();
176
177        for tool_use in tool_uses_for_message.iter() {
178            let tool_result = self.tool_results.get(&tool_use.id);
179
180            let status = (|| {
181                if let Some(tool_result) = tool_result {
182                    let content = tool_result
183                        .content
184                        .to_str()
185                        .map(|str| str.to_owned().into())
186                        .unwrap_or_default();
187
188                    return if tool_result.is_error {
189                        ToolUseStatus::Error(content)
190                    } else {
191                        ToolUseStatus::Finished(content)
192                    };
193                }
194
195                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
196                    match pending_tool_use.status {
197                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
198                        PendingToolUseStatus::NeedsConfirmation { .. } => {
199                            ToolUseStatus::NeedsConfirmation
200                        }
201                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
202                        PendingToolUseStatus::Error(ref err) => {
203                            ToolUseStatus::Error(err.clone().into())
204                        }
205                        PendingToolUseStatus::InputStillStreaming => {
206                            ToolUseStatus::InputStillStreaming
207                        }
208                    }
209                } else {
210                    ToolUseStatus::Pending
211                }
212            })();
213
214            let (icon, needs_confirmation) =
215                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
216                    (
217                        tool.icon(),
218                        tool.needs_confirmation(&tool_use.input, project, cx),
219                    )
220                } else {
221                    (IconName::Cog, false)
222                };
223
224            tool_uses.push(ToolUse {
225                id: tool_use.id.clone(),
226                name: tool_use.name.clone().into(),
227                ui_text: self.tool_ui_label(
228                    &tool_use.name,
229                    &tool_use.input,
230                    tool_use.is_input_complete,
231                    cx,
232                ),
233                input: tool_use.input.clone(),
234                status,
235                icon,
236                needs_confirmation,
237            })
238        }
239
240        tool_uses
241    }
242
243    pub fn tool_ui_label(
244        &self,
245        tool_name: &str,
246        input: &serde_json::Value,
247        is_input_complete: bool,
248        cx: &App,
249    ) -> SharedString {
250        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
251            if is_input_complete {
252                tool.ui_text(input).into()
253            } else {
254                tool.still_streaming_ui_text(input).into()
255            }
256        } else {
257            format!("Unknown tool {tool_name:?}").into()
258        }
259    }
260
261    pub fn tool_results_for_message(
262        &self,
263        assistant_message_id: MessageId,
264    ) -> Vec<&LanguageModelToolResult> {
265        let Some(tool_uses) = self
266            .tool_uses_by_assistant_message
267            .get(&assistant_message_id)
268        else {
269            return Vec::new();
270        };
271
272        tool_uses
273            .iter()
274            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
275            .collect()
276    }
277
278    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
279        self.tool_uses_by_assistant_message
280            .get(&assistant_message_id)
281            .map_or(false, |results| !results.is_empty())
282    }
283
    /// Returns the stored result for `tool_use_id`, or `None` when no result
    /// has been recorded for it.
    pub fn tool_result(
        &self,
        tool_use_id: &LanguageModelToolUseId,
    ) -> Option<&LanguageModelToolResult> {
        self.tool_results.get(tool_use_id)
    }
290
    /// Returns the card attached to `tool_use_id`, if one was deserialized or inserted.
    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
        self.tool_result_cards.get(tool_use_id)
    }
294
    /// Attaches `card` to `tool_use_id`, replacing any card stored for it before.
    pub fn insert_tool_result_card(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        card: AnyToolCard,
    ) {
        self.tool_result_cards.insert(tool_use_id, card);
    }
302
303    pub fn request_tool_use(
304        &mut self,
305        assistant_message_id: MessageId,
306        tool_use: LanguageModelToolUse,
307        metadata: ToolUseMetadata,
308        cx: &App,
309    ) -> Arc<str> {
310        let tool_uses = self
311            .tool_uses_by_assistant_message
312            .entry(assistant_message_id)
313            .or_default();
314
315        let mut existing_tool_use_found = false;
316
317        for existing_tool_use in tool_uses.iter_mut() {
318            if existing_tool_use.id == tool_use.id {
319                *existing_tool_use = tool_use.clone();
320                existing_tool_use_found = true;
321            }
322        }
323
324        if !existing_tool_use_found {
325            tool_uses.push(tool_use.clone());
326        }
327
328        let status = if tool_use.is_input_complete {
329            self.tool_use_metadata_by_id
330                .insert(tool_use.id.clone(), metadata);
331
332            PendingToolUseStatus::Idle
333        } else {
334            PendingToolUseStatus::InputStillStreaming
335        };
336
337        let ui_text: Arc<str> = self
338            .tool_ui_label(
339                &tool_use.name,
340                &tool_use.input,
341                tool_use.is_input_complete,
342                cx,
343            )
344            .into();
345
346        let may_perform_edits = self
347            .tools
348            .read(cx)
349            .tool(&tool_use.name, cx)
350            .is_some_and(|tool| tool.may_perform_edits());
351
352        self.pending_tool_uses_by_id.insert(
353            tool_use.id.clone(),
354            PendingToolUse {
355                assistant_message_id,
356                id: tool_use.id,
357                name: tool_use.name.clone(),
358                ui_text: ui_text.clone(),
359                input: tool_use.input,
360                may_perform_edits,
361                status,
362            },
363        );
364
365        ui_text
366    }
367
368    pub fn run_pending_tool(
369        &mut self,
370        tool_use_id: LanguageModelToolUseId,
371        ui_text: SharedString,
372        task: Task<()>,
373    ) {
374        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
375            tool_use.ui_text = ui_text.into();
376            tool_use.status = PendingToolUseStatus::Running {
377                _task: task.shared(),
378            };
379        }
380    }
381
382    pub fn confirm_tool_use(
383        &mut self,
384        tool_use_id: LanguageModelToolUseId,
385        ui_text: impl Into<Arc<str>>,
386        input: serde_json::Value,
387        request: Arc<LanguageModelRequest>,
388        tool: Arc<dyn Tool>,
389    ) {
390        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
391            let ui_text = ui_text.into();
392            tool_use.ui_text = ui_text.clone();
393            let confirmation = Confirmation {
394                tool_use_id,
395                input,
396                request,
397                tool,
398                ui_text,
399            };
400            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
401        }
402    }
403
404    pub fn insert_tool_output(
405        &mut self,
406        tool_use_id: LanguageModelToolUseId,
407        tool_name: Arc<str>,
408        output: Result<ToolResultOutput>,
409        configured_model: Option<&ConfiguredModel>,
410        completion_mode: CompletionMode,
411    ) -> Option<PendingToolUse> {
412        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
413
414        telemetry::event!(
415            "Agent Tool Finished",
416            model = metadata
417                .as_ref()
418                .map(|metadata| metadata.model.telemetry_id()),
419            model_provider = metadata
420                .as_ref()
421                .map(|metadata| metadata.model.provider_id().to_string()),
422            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
423            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
424            tool_name,
425            success = output.is_ok()
426        );
427
428        match output {
429            Ok(output) => {
430                let tool_result = output.content;
431                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
432
433                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
434
435                // Protect from overly large output
436                let tool_output_limit = configured_model
437                    .map(|model| {
438                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
439                            * BYTES_PER_TOKEN_ESTIMATE
440                    })
441                    .unwrap_or(usize::MAX);
442
443                let content = match tool_result {
444                    ToolResultContent::Text(text) => {
445                        let text = if text.len() < tool_output_limit {
446                            text
447                        } else {
448                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
449                            format!(
450                                "Tool result too long. The first {} bytes:\n\n{}",
451                                truncated.len(),
452                                truncated
453                            )
454                        };
455                        LanguageModelToolResultContent::Text(text.into())
456                    }
457                    ToolResultContent::Image(language_model_image) => {
458                        if language_model_image.estimate_tokens() < tool_output_limit {
459                            LanguageModelToolResultContent::Image(language_model_image)
460                        } else {
461                            self.tool_results.insert(
462                                tool_use_id.clone(),
463                                LanguageModelToolResult {
464                                    tool_use_id: tool_use_id.clone(),
465                                    tool_name,
466                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
467                                    is_error: true,
468                                    output: None,
469                                },
470                            );
471
472                            return old_use;
473                        }
474                    }
475                };
476
477                self.tool_results.insert(
478                    tool_use_id.clone(),
479                    LanguageModelToolResult {
480                        tool_use_id: tool_use_id.clone(),
481                        tool_name,
482                        content,
483                        is_error: false,
484                        output: output.output,
485                    },
486                );
487
488                old_use
489            }
490            Err(err) => {
491                self.tool_results.insert(
492                    tool_use_id.clone(),
493                    LanguageModelToolResult {
494                        tool_use_id: tool_use_id.clone(),
495                        tool_name,
496                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
497                        is_error: true,
498                        output: None,
499                    },
500                );
501
502                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
503                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
504                }
505
506                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
507            }
508        }
509    }
510
    /// Returns `true` when the given assistant message has an entry in the
    /// recorded tool uses.
    ///
    /// Note that this checks the tool-use map only; it does not look at
    /// `tool_results`, and an entry with an empty list still counts.
    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
        self.tool_uses_by_assistant_message
            .contains_key(&assistant_message_id)
    }
515
516    pub fn tool_results(
517        &self,
518        assistant_message_id: MessageId,
519    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
520        self.tool_uses_by_assistant_message
521            .get(&assistant_message_id)
522            .into_iter()
523            .flatten()
524            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
525    }
526}
527
/// A tool use that has been requested and is tracked in
/// `ToolUseState::pending_tool_uses_by_id`.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): the reason for `allow(unused)` is not visible in this file.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Current UI label; updated when the tool starts running or asks for confirmation.
    pub ui_text: Arc<str>,
    /// Input the tool was requested with.
    pub input: serde_json::Value,
    /// Where the tool use currently stands.
    pub status: PendingToolUseStatus,
    /// What the tool's `may_perform_edits` reported at request time; `false` for unknown tools.
    pub may_perform_edits: bool,
}
540
/// The arguments captured by `ToolUseState::confirm_tool_use`, stored inside
/// `PendingToolUseStatus::NeedsConfirmation`.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Input passed to `confirm_tool_use`.
    pub input: serde_json::Value,
    /// UI label; the same value is stored on the pending tool use.
    pub ui_text: Arc<str>,
    /// Request passed to `confirm_tool_use` (presumably the one the tool runs
    /// with once confirmed - verify against the caller).
    pub request: Arc<LanguageModelRequest>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
549
550#[derive(Debug, Clone)]
551pub enum PendingToolUseStatus {
552    InputStillStreaming,
553    Idle,
554    NeedsConfirmation(Arc<Confirmation>),
555    Running { _task: Shared<Task<()>> },
556    Error(#[allow(unused)] Arc<str>),
557}
558
559impl PendingToolUseStatus {
560    pub fn is_idle(&self) -> bool {
561        matches!(self, PendingToolUseStatus::Idle)
562    }
563
564    pub fn is_error(&self) -> bool {
565        matches!(self, PendingToolUseStatus::Error(_))
566    }
567
568    pub fn needs_confirmation(&self) -> bool {
569        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
570    }
571}
572
/// Context stored alongside a tool use once its input is complete and read
/// back by `ToolUseState::insert_tool_output` for its telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry id and provider id are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread id reported with the event.
    pub thread_id: ThreadId,
    /// Prompt id reported with the event.
    pub prompt_id: PromptId,
}