tool_use.rs

  1use crate::{
  2    thread::{MessageId, PromptId, ThreadId},
  3    thread_store::SerializedMessage,
  4};
  5use anyhow::Result;
  6use assistant_tool::{
  7    AnyTool, AnyToolCard, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  8};
  9use collections::HashMap;
 10use futures::{FutureExt as _, future::Shared};
 11use gpui::{App, Entity, SharedString, Task, Window};
 12use icons::IconName;
 13use language_model::{
 14    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
 15    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
 16};
 17use project::Project;
 18use std::sync::Arc;
 19use util::truncate_lines_to_byte_limit;
 20
/// A snapshot of a single tool use, as assembled by
/// [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use this snapshot describes.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label produced by [`ToolUseState::tool_ui_label`] for this tool use.
    pub ui_text: SharedString,
    /// Status derived from the stored tool result, or from the pending tool use
    /// when no result has been stored yet.
    pub status: ToolUseStatus,
    /// The input passed to the tool.
    pub input: serde_json::Value,
    /// Icon reported by the tool; `IconName::Cog` when the tool is not found
    /// in the working set.
    pub icon: icons::IconName,
    /// Whether the tool reports that this input needs confirmation; `false`
    /// when the tool is not found in the working set.
    pub needs_confirmation: bool,
}
 31
/// Tracks the tool uses requested by assistant messages together with their
/// pending state, results, result cards, and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses recorded for each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Stored tool results, keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have not been removed yet. Entries are
    /// removed on successful output and on cancellation; failed ones are kept
    /// with an error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards associated with tool results, keyed by tool use id.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete and consumed when
    /// the tool output is inserted (used for the telemetry event).
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 40
 41impl ToolUseState {
 42    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 43        Self {
 44            tools,
 45            tool_uses_by_assistant_message: HashMap::default(),
 46            tool_results: HashMap::default(),
 47            pending_tool_uses_by_id: HashMap::default(),
 48            tool_result_cards: HashMap::default(),
 49            tool_use_metadata_by_id: HashMap::default(),
 50        }
 51    }
 52
 53    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 54    ///
 55    /// Accepts a function to filter the tools that should be used to populate the state.
 56    ///
 57    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 58    /// tool cards won't be deserialized
 59    pub fn from_serialized_messages(
 60        tools: Entity<ToolWorkingSet>,
 61        messages: &[SerializedMessage],
 62        project: Entity<Project>,
 63        window: Option<&mut Window>, // None in headless mode
 64        cx: &mut App,
 65    ) -> Self {
 66        let mut this = Self::new(tools);
 67        let mut tool_names_by_id = HashMap::default();
 68        let mut window = window;
 69
 70        for message in messages {
 71            match message.role {
 72                Role::Assistant => {
 73                    if !message.tool_uses.is_empty() {
 74                        let tool_uses = message
 75                            .tool_uses
 76                            .iter()
 77                            .map(|tool_use| LanguageModelToolUse {
 78                                id: tool_use.id.clone(),
 79                                name: tool_use.name.clone().into(),
 80                                raw_input: tool_use.input.to_string(),
 81                                input: tool_use.input.clone(),
 82                                is_input_complete: true,
 83                            })
 84                            .collect::<Vec<_>>();
 85
 86                        tool_names_by_id.extend(
 87                            tool_uses
 88                                .iter()
 89                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 90                        );
 91
 92                        this.tool_uses_by_assistant_message
 93                            .insert(message.id, tool_uses);
 94
 95                        for tool_result in &message.tool_results {
 96                            let tool_use_id = tool_result.tool_use_id.clone();
 97                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 98                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 99                                continue;
100                            };
101
102                            this.tool_results.insert(
103                                tool_use_id.clone(),
104                                LanguageModelToolResult {
105                                    tool_use_id: tool_use_id.clone(),
106                                    tool_name: tool_use.clone(),
107                                    is_error: tool_result.is_error,
108                                    content: tool_result.content.clone(),
109                                    output: tool_result.output.clone(),
110                                },
111                            );
112
113                            if let Some(window) = &mut window {
114                                if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
115                                    if let Some(output) = tool_result.output.clone() {
116                                        if let Some(card) = tool.deserialize_card(
117                                            output,
118                                            project.clone(),
119                                            window,
120                                            cx,
121                                        ) {
122                                            this.tool_result_cards.insert(tool_use_id, card);
123                                        }
124                                    }
125                                }
126                            }
127                        }
128                    }
129                }
130                Role::System | Role::User => {}
131            }
132        }
133
134        this
135    }
136
137    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
138        let mut cancelled_tool_uses = Vec::new();
139        self.pending_tool_uses_by_id
140            .retain(|tool_use_id, tool_use| {
141                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
142                    return true;
143                }
144
145                let content = "Tool canceled by user".into();
146                self.tool_results.insert(
147                    tool_use_id.clone(),
148                    LanguageModelToolResult {
149                        tool_use_id: tool_use_id.clone(),
150                        tool_name: tool_use.name.clone(),
151                        content,
152                        output: None,
153                        is_error: true,
154                    },
155                );
156                cancelled_tool_uses.push(tool_use.clone());
157                false
158            });
159        cancelled_tool_uses
160    }
161
162    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
163        self.pending_tool_uses_by_id.values().collect()
164    }
165
166    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
167        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
168            return Vec::new();
169        };
170
171        let mut tool_uses = Vec::new();
172
173        for tool_use in tool_uses_for_message.iter() {
174            let tool_result = self.tool_results.get(&tool_use.id);
175
176            let status = (|| {
177                if let Some(tool_result) = tool_result {
178                    let content = tool_result
179                        .content
180                        .to_str()
181                        .map(|str| str.to_owned().into())
182                        .unwrap_or_default();
183
184                    return if tool_result.is_error {
185                        ToolUseStatus::Error(content)
186                    } else {
187                        ToolUseStatus::Finished(content)
188                    };
189                }
190
191                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
192                    match pending_tool_use.status {
193                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
194                        PendingToolUseStatus::NeedsConfirmation { .. } => {
195                            ToolUseStatus::NeedsConfirmation
196                        }
197                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
198                        PendingToolUseStatus::Error(ref err) => {
199                            ToolUseStatus::Error(err.clone().into())
200                        }
201                        PendingToolUseStatus::InputStillStreaming => {
202                            ToolUseStatus::InputStillStreaming
203                        }
204                    }
205                } else {
206                    ToolUseStatus::Pending
207                }
208            })();
209
210            let (icon, needs_confirmation) =
211                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
212                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
213                } else {
214                    (IconName::Cog, false)
215                };
216
217            tool_uses.push(ToolUse {
218                id: tool_use.id.clone(),
219                name: tool_use.name.clone().into(),
220                ui_text: self.tool_ui_label(
221                    &tool_use.name,
222                    &tool_use.input,
223                    tool_use.is_input_complete,
224                    cx,
225                ),
226                input: tool_use.input.clone(),
227                status,
228                icon,
229                needs_confirmation,
230            })
231        }
232
233        tool_uses
234    }
235
236    pub fn tool_ui_label(
237        &self,
238        tool_name: &str,
239        input: &serde_json::Value,
240        is_input_complete: bool,
241        cx: &App,
242    ) -> SharedString {
243        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
244            if is_input_complete {
245                tool.ui_text(input).into()
246            } else {
247                tool.still_streaming_ui_text(input).into()
248            }
249        } else {
250            format!("Unknown tool {tool_name:?}").into()
251        }
252    }
253
254    pub fn tool_results_for_message(
255        &self,
256        assistant_message_id: MessageId,
257    ) -> Vec<&LanguageModelToolResult> {
258        let Some(tool_uses) = self
259            .tool_uses_by_assistant_message
260            .get(&assistant_message_id)
261        else {
262            return Vec::new();
263        };
264
265        tool_uses
266            .iter()
267            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
268            .collect()
269    }
270
271    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
272        self.tool_uses_by_assistant_message
273            .get(&assistant_message_id)
274            .map_or(false, |results| !results.is_empty())
275    }
276
277    pub fn tool_result(
278        &self,
279        tool_use_id: &LanguageModelToolUseId,
280    ) -> Option<&LanguageModelToolResult> {
281        self.tool_results.get(tool_use_id)
282    }
283
284    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
285        self.tool_result_cards.get(tool_use_id)
286    }
287
288    pub fn insert_tool_result_card(
289        &mut self,
290        tool_use_id: LanguageModelToolUseId,
291        card: AnyToolCard,
292    ) {
293        self.tool_result_cards.insert(tool_use_id, card);
294    }
295
296    pub fn request_tool_use(
297        &mut self,
298        assistant_message_id: MessageId,
299        tool_use: LanguageModelToolUse,
300        metadata: ToolUseMetadata,
301        cx: &App,
302    ) -> Arc<str> {
303        let tool_uses = self
304            .tool_uses_by_assistant_message
305            .entry(assistant_message_id)
306            .or_default();
307
308        let mut existing_tool_use_found = false;
309
310        for existing_tool_use in tool_uses.iter_mut() {
311            if existing_tool_use.id == tool_use.id {
312                *existing_tool_use = tool_use.clone();
313                existing_tool_use_found = true;
314            }
315        }
316
317        if !existing_tool_use_found {
318            tool_uses.push(tool_use.clone());
319        }
320
321        let status = if tool_use.is_input_complete {
322            self.tool_use_metadata_by_id
323                .insert(tool_use.id.clone(), metadata);
324
325            PendingToolUseStatus::Idle
326        } else {
327            PendingToolUseStatus::InputStillStreaming
328        };
329
330        let ui_text: Arc<str> = self
331            .tool_ui_label(
332                &tool_use.name,
333                &tool_use.input,
334                tool_use.is_input_complete,
335                cx,
336            )
337            .into();
338
339        let may_perform_edits = self
340            .tools
341            .read(cx)
342            .tool(&tool_use.name, cx)
343            .is_some_and(|tool| tool.may_perform_edits());
344
345        self.pending_tool_uses_by_id.insert(
346            tool_use.id.clone(),
347            PendingToolUse {
348                assistant_message_id,
349                id: tool_use.id,
350                name: tool_use.name.clone(),
351                ui_text: ui_text.clone(),
352                input: tool_use.input,
353                may_perform_edits,
354                status,
355            },
356        );
357
358        ui_text
359    }
360
361    pub fn run_pending_tool(
362        &mut self,
363        tool_use_id: LanguageModelToolUseId,
364        ui_text: SharedString,
365        task: Task<()>,
366    ) {
367        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
368            tool_use.ui_text = ui_text.into();
369            tool_use.status = PendingToolUseStatus::Running {
370                _task: task.shared(),
371            };
372        }
373    }
374
375    pub fn confirm_tool_use(
376        &mut self,
377        tool_use_id: LanguageModelToolUseId,
378        ui_text: impl Into<Arc<str>>,
379        input: serde_json::Value,
380        request: Arc<LanguageModelRequest>,
381        tool: AnyTool,
382    ) {
383        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
384            let ui_text = ui_text.into();
385            tool_use.ui_text = ui_text.clone();
386            let confirmation = Confirmation {
387                tool_use_id,
388                input,
389                request,
390                tool,
391                ui_text,
392            };
393            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
394        }
395    }
396
397    pub fn insert_tool_output(
398        &mut self,
399        tool_use_id: LanguageModelToolUseId,
400        tool_name: Arc<str>,
401        output: Result<ToolResultOutput>,
402        configured_model: Option<&ConfiguredModel>,
403    ) -> Option<PendingToolUse> {
404        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
405
406        telemetry::event!(
407            "Agent Tool Finished",
408            model = metadata
409                .as_ref()
410                .map(|metadata| metadata.model.telemetry_id()),
411            model_provider = metadata
412                .as_ref()
413                .map(|metadata| metadata.model.provider_id().to_string()),
414            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
415            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
416            tool_name,
417            success = output.is_ok()
418        );
419
420        match output {
421            Ok(output) => {
422                let tool_result = output.content;
423                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
424
425                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
426
427                // Protect from overly large output
428                let tool_output_limit = configured_model
429                    .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
430                    .unwrap_or(usize::MAX);
431
432                let content = match tool_result {
433                    ToolResultContent::Text(text) => {
434                        let text = if text.len() < tool_output_limit {
435                            text
436                        } else {
437                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
438                            format!(
439                                "Tool result too long. The first {} bytes:\n\n{}",
440                                truncated.len(),
441                                truncated
442                            )
443                        };
444                        LanguageModelToolResultContent::Text(text.into())
445                    }
446                    ToolResultContent::Image(language_model_image) => {
447                        if language_model_image.estimate_tokens() < tool_output_limit {
448                            LanguageModelToolResultContent::Image(language_model_image)
449                        } else {
450                            self.tool_results.insert(
451                                tool_use_id.clone(),
452                                LanguageModelToolResult {
453                                    tool_use_id: tool_use_id.clone(),
454                                    tool_name,
455                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
456                                    is_error: true,
457                                    output: None,
458                                },
459                            );
460
461                            return old_use;
462                        }
463                    }
464                };
465
466                self.tool_results.insert(
467                    tool_use_id.clone(),
468                    LanguageModelToolResult {
469                        tool_use_id: tool_use_id.clone(),
470                        tool_name,
471                        content,
472                        is_error: false,
473                        output: output.output,
474                    },
475                );
476
477                old_use
478            }
479            Err(err) => {
480                self.tool_results.insert(
481                    tool_use_id.clone(),
482                    LanguageModelToolResult {
483                        tool_use_id: tool_use_id.clone(),
484                        tool_name,
485                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
486                        is_error: true,
487                        output: None,
488                    },
489                );
490
491                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
492                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
493                }
494
495                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
496            }
497        }
498    }
499
500    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
501        self.tool_uses_by_assistant_message
502            .contains_key(&assistant_message_id)
503    }
504
505    pub fn tool_results(
506        &self,
507        assistant_message_id: MessageId,
508    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
509        self.tool_uses_by_assistant_message
510            .get(&assistant_message_id)
511            .into_iter()
512            .flatten()
513            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
514    }
515}
516
/// A tool use that has been requested and is tracked in the pending set of
/// [`ToolUseState`].
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label shown for this tool use; updated when the tool is run or needs
    /// confirmation.
    pub ui_text: Arc<str>,
    /// The input passed to the tool.
    pub input: serde_json::Value,
    /// Where the tool use currently stands.
    pub status: PendingToolUseStatus,
    /// Whether the tool reported that it may perform edits; `false` when the
    /// tool was not found in the working set.
    pub may_perform_edits: bool,
}
529
/// The data captured by [`ToolUseState::confirm_tool_use`] when a tool use is
/// put into the needs-confirmation state.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// The input passed to the tool.
    pub input: serde_json::Value,
    /// Label shown for the tool use.
    pub ui_text: Arc<str>,
    /// The language model request supplied alongside the confirmation.
    pub request: Arc<LanguageModelRequest>,
    /// The tool awaiting confirmation.
    pub tool: AnyTool,
}
538
/// The state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The tool use was requested while its input was not yet complete.
    InputStillStreaming,
    /// The tool use was requested with complete input and has not been moved
    /// to another state yet.
    Idle,
    /// The tool use is waiting for confirmation; carries the captured details.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool use is running; holds a shared handle to its task.
    Running { _task: Shared<Task<()>> },
    /// The tool produced an error; carries the error text.
    Error(#[allow(unused)] Arc<str>),
}
547
548impl PendingToolUseStatus {
549    pub fn is_idle(&self) -> bool {
550        matches!(self, PendingToolUseStatus::Idle)
551    }
552
553    pub fn is_error(&self) -> bool {
554        matches!(self, PendingToolUseStatus::Error(_))
555    }
556
557    pub fn needs_confirmation(&self) -> bool {
558        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
559    }
560}
561
/// Context stored for a tool use once its input is complete; read back when
/// the tool output is inserted to populate the telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// The language model; its telemetry id and provider id are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Identifier of the thread, reported in telemetry.
    pub thread_id: ThreadId,
    /// Identifier of the prompt, reported in telemetry.
    pub prompt_id: PromptId,
}