tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{
  5    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  6};
  7use collections::HashMap;
  8use futures::FutureExt as _;
  9use futures::future::Shared;
 10use gpui::{App, Entity, SharedString, Task};
 11use language_model::{
 12    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
 13    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
 14};
 15use project::Project;
 16use ui::{IconName, Window};
 17use util::truncate_lines_to_byte_limit;
 18
 19use crate::thread::{MessageId, PromptId, ThreadId};
 20use crate::thread_store::SerializedMessage;
 21
 22#[derive(Debug)]
 23pub struct ToolUse {
 24    pub id: LanguageModelToolUseId,
 25    pub name: SharedString,
 26    pub ui_text: SharedString,
 27    pub status: ToolUseStatus,
 28    pub input: serde_json::Value,
 29    pub icon: ui::IconName,
 30    pub needs_confirmation: bool,
 31}
 32
 33pub struct ToolUseState {
 34    tools: Entity<ToolWorkingSet>,
 35    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 36    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 37    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 38    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
 39    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
 40}
 41
 42impl ToolUseState {
 43    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 44        Self {
 45            tools,
 46            tool_uses_by_assistant_message: HashMap::default(),
 47            tool_results: HashMap::default(),
 48            pending_tool_uses_by_id: HashMap::default(),
 49            tool_result_cards: HashMap::default(),
 50            tool_use_metadata_by_id: HashMap::default(),
 51        }
 52    }
 53
 54    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 55    ///
 56    /// Accepts a function to filter the tools that should be used to populate the state.
 57    pub fn from_serialized_messages(
 58        tools: Entity<ToolWorkingSet>,
 59        messages: &[SerializedMessage],
 60        project: Entity<Project>,
 61        window: &mut Window,
 62        cx: &mut App,
 63    ) -> Self {
 64        let mut this = Self::new(tools);
 65        let mut tool_names_by_id = HashMap::default();
 66
 67        for message in messages {
 68            match message.role {
 69                Role::Assistant => {
 70                    if !message.tool_uses.is_empty() {
 71                        let tool_uses = message
 72                            .tool_uses
 73                            .iter()
 74                            .map(|tool_use| LanguageModelToolUse {
 75                                id: tool_use.id.clone(),
 76                                name: tool_use.name.clone().into(),
 77                                raw_input: tool_use.input.to_string(),
 78                                input: tool_use.input.clone(),
 79                                is_input_complete: true,
 80                            })
 81                            .collect::<Vec<_>>();
 82
 83                        tool_names_by_id.extend(
 84                            tool_uses
 85                                .iter()
 86                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 87                        );
 88
 89                        this.tool_uses_by_assistant_message
 90                            .insert(message.id, tool_uses);
 91
 92                        for tool_result in &message.tool_results {
 93                            let tool_use_id = tool_result.tool_use_id.clone();
 94                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 95                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 96                                continue;
 97                            };
 98
 99                            this.tool_results.insert(
100                                tool_use_id.clone(),
101                                LanguageModelToolResult {
102                                    tool_use_id: tool_use_id.clone(),
103                                    tool_name: tool_use.clone(),
104                                    is_error: tool_result.is_error,
105                                    content: tool_result.content.clone(),
106                                    output: tool_result.output.clone(),
107                                },
108                            );
109
110                            if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
111                                if let Some(output) = tool_result.output.clone() {
112                                    if let Some(card) =
113                                        tool.deserialize_card(output, project.clone(), window, cx)
114                                    {
115                                        this.tool_result_cards.insert(tool_use_id, card);
116                                    }
117                                }
118                            }
119                        }
120                    }
121                }
122                Role::System | Role::User => {}
123            }
124        }
125
126        this
127    }
128
129    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130        let mut cancelled_tool_uses = Vec::new();
131        self.pending_tool_uses_by_id
132            .retain(|tool_use_id, tool_use| {
133                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
134                    return true;
135                }
136
137                let content = "Tool canceled by user".into();
138                self.tool_results.insert(
139                    tool_use_id.clone(),
140                    LanguageModelToolResult {
141                        tool_use_id: tool_use_id.clone(),
142                        tool_name: tool_use.name.clone(),
143                        content,
144                        output: None,
145                        is_error: true,
146                    },
147                );
148                cancelled_tool_uses.push(tool_use.clone());
149                false
150            });
151        cancelled_tool_uses
152    }
153
154    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
155        self.pending_tool_uses_by_id.values().collect()
156    }
157
158    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
159        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
160            return Vec::new();
161        };
162
163        let mut tool_uses = Vec::new();
164
165        for tool_use in tool_uses_for_message.iter() {
166            let tool_result = self.tool_results.get(&tool_use.id);
167
168            let status = (|| {
169                if let Some(tool_result) = tool_result {
170                    let content = tool_result
171                        .content
172                        .to_str()
173                        .map(|str| str.to_owned().into())
174                        .unwrap_or_default();
175
176                    return if tool_result.is_error {
177                        ToolUseStatus::Error(content)
178                    } else {
179                        ToolUseStatus::Finished(content)
180                    };
181                }
182
183                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
184                    match pending_tool_use.status {
185                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
186                        PendingToolUseStatus::NeedsConfirmation { .. } => {
187                            ToolUseStatus::NeedsConfirmation
188                        }
189                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
190                        PendingToolUseStatus::Error(ref err) => {
191                            ToolUseStatus::Error(err.clone().into())
192                        }
193                        PendingToolUseStatus::InputStillStreaming => {
194                            ToolUseStatus::InputStillStreaming
195                        }
196                    }
197                } else {
198                    ToolUseStatus::Pending
199                }
200            })();
201
202            let (icon, needs_confirmation) =
203                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
204                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
205                } else {
206                    (IconName::Cog, false)
207                };
208
209            tool_uses.push(ToolUse {
210                id: tool_use.id.clone(),
211                name: tool_use.name.clone().into(),
212                ui_text: self.tool_ui_label(
213                    &tool_use.name,
214                    &tool_use.input,
215                    tool_use.is_input_complete,
216                    cx,
217                ),
218                input: tool_use.input.clone(),
219                status,
220                icon,
221                needs_confirmation,
222            })
223        }
224
225        tool_uses
226    }
227
228    pub fn tool_ui_label(
229        &self,
230        tool_name: &str,
231        input: &serde_json::Value,
232        is_input_complete: bool,
233        cx: &App,
234    ) -> SharedString {
235        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
236            if is_input_complete {
237                tool.ui_text(input).into()
238            } else {
239                tool.still_streaming_ui_text(input).into()
240            }
241        } else {
242            format!("Unknown tool {tool_name:?}").into()
243        }
244    }
245
246    pub fn tool_results_for_message(
247        &self,
248        assistant_message_id: MessageId,
249    ) -> Vec<&LanguageModelToolResult> {
250        let Some(tool_uses) = self
251            .tool_uses_by_assistant_message
252            .get(&assistant_message_id)
253        else {
254            return Vec::new();
255        };
256
257        tool_uses
258            .iter()
259            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
260            .collect()
261    }
262
263    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
264        self.tool_uses_by_assistant_message
265            .get(&assistant_message_id)
266            .map_or(false, |results| !results.is_empty())
267    }
268
269    pub fn tool_result(
270        &self,
271        tool_use_id: &LanguageModelToolUseId,
272    ) -> Option<&LanguageModelToolResult> {
273        self.tool_results.get(tool_use_id)
274    }
275
276    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
277        self.tool_result_cards.get(tool_use_id)
278    }
279
280    pub fn insert_tool_result_card(
281        &mut self,
282        tool_use_id: LanguageModelToolUseId,
283        card: AnyToolCard,
284    ) {
285        self.tool_result_cards.insert(tool_use_id, card);
286    }
287
288    pub fn request_tool_use(
289        &mut self,
290        assistant_message_id: MessageId,
291        tool_use: LanguageModelToolUse,
292        metadata: ToolUseMetadata,
293        cx: &App,
294    ) -> Arc<str> {
295        let tool_uses = self
296            .tool_uses_by_assistant_message
297            .entry(assistant_message_id)
298            .or_default();
299
300        let mut existing_tool_use_found = false;
301
302        for existing_tool_use in tool_uses.iter_mut() {
303            if existing_tool_use.id == tool_use.id {
304                *existing_tool_use = tool_use.clone();
305                existing_tool_use_found = true;
306            }
307        }
308
309        if !existing_tool_use_found {
310            tool_uses.push(tool_use.clone());
311        }
312
313        let status = if tool_use.is_input_complete {
314            self.tool_use_metadata_by_id
315                .insert(tool_use.id.clone(), metadata);
316
317            PendingToolUseStatus::Idle
318        } else {
319            PendingToolUseStatus::InputStillStreaming
320        };
321
322        let ui_text: Arc<str> = self
323            .tool_ui_label(
324                &tool_use.name,
325                &tool_use.input,
326                tool_use.is_input_complete,
327                cx,
328            )
329            .into();
330
331        self.pending_tool_uses_by_id.insert(
332            tool_use.id.clone(),
333            PendingToolUse {
334                assistant_message_id,
335                id: tool_use.id,
336                name: tool_use.name.clone(),
337                ui_text: ui_text.clone(),
338                input: tool_use.input,
339                status,
340            },
341        );
342
343        ui_text
344    }
345
346    pub fn run_pending_tool(
347        &mut self,
348        tool_use_id: LanguageModelToolUseId,
349        ui_text: SharedString,
350        task: Task<()>,
351    ) {
352        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
353            tool_use.ui_text = ui_text.into();
354            tool_use.status = PendingToolUseStatus::Running {
355                _task: task.shared(),
356            };
357        }
358    }
359
360    pub fn confirm_tool_use(
361        &mut self,
362        tool_use_id: LanguageModelToolUseId,
363        ui_text: impl Into<Arc<str>>,
364        input: serde_json::Value,
365        request: Arc<LanguageModelRequest>,
366        tool: Arc<dyn Tool>,
367    ) {
368        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
369            let ui_text = ui_text.into();
370            tool_use.ui_text = ui_text.clone();
371            let confirmation = Confirmation {
372                tool_use_id,
373                input,
374                request,
375                tool,
376                ui_text,
377            };
378            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
379        }
380    }
381
382    pub fn insert_tool_output(
383        &mut self,
384        tool_use_id: LanguageModelToolUseId,
385        tool_name: Arc<str>,
386        output: Result<ToolResultOutput>,
387        configured_model: Option<&ConfiguredModel>,
388    ) -> Option<PendingToolUse> {
389        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
390
391        telemetry::event!(
392            "Agent Tool Finished",
393            model = metadata
394                .as_ref()
395                .map(|metadata| metadata.model.telemetry_id()),
396            model_provider = metadata
397                .as_ref()
398                .map(|metadata| metadata.model.provider_id().to_string()),
399            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
400            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
401            tool_name,
402            success = output.is_ok()
403        );
404
405        match output {
406            Ok(output) => {
407                let tool_result = output.content;
408                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
409
410                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
411
412                // Protect from overly large output
413                let tool_output_limit = configured_model
414                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
415                    .unwrap_or(usize::MAX);
416
417                let content = match tool_result {
418                    ToolResultContent::Text(text) => {
419                        let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
420
421                        LanguageModelToolResultContent::Text(
422                            format!(
423                                "Tool result too long. The first {} bytes:\n\n{}",
424                                truncated.len(),
425                                truncated
426                            )
427                            .into(),
428                        )
429                    }
430                    ToolResultContent::Image(language_model_image) => {
431                        if language_model_image.estimate_tokens() < tool_output_limit {
432                            LanguageModelToolResultContent::Image(language_model_image)
433                        } else {
434                            self.tool_results.insert(
435                                tool_use_id.clone(),
436                                LanguageModelToolResult {
437                                    tool_use_id: tool_use_id.clone(),
438                                    tool_name,
439                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
440                                    is_error: true,
441                                    output: None,
442                                },
443                            );
444
445                            return old_use;
446                        }
447                    }
448                };
449
450                self.tool_results.insert(
451                    tool_use_id.clone(),
452                    LanguageModelToolResult {
453                        tool_use_id: tool_use_id.clone(),
454                        tool_name,
455                        content,
456                        is_error: false,
457                        output: output.output,
458                    },
459                );
460
461                old_use
462            }
463            Err(err) => {
464                self.tool_results.insert(
465                    tool_use_id.clone(),
466                    LanguageModelToolResult {
467                        tool_use_id: tool_use_id.clone(),
468                        tool_name,
469                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
470                        is_error: true,
471                        output: None,
472                    },
473                );
474
475                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
476                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
477                }
478
479                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
480            }
481        }
482    }
483
484    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
485        self.tool_uses_by_assistant_message
486            .contains_key(&assistant_message_id)
487    }
488
489    pub fn tool_results(
490        &self,
491        assistant_message_id: MessageId,
492    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
493        self.tool_uses_by_assistant_message
494            .get(&assistant_message_id)
495            .into_iter()
496            .flatten()
497            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
498    }
499}
500
501#[derive(Debug, Clone)]
502pub struct PendingToolUse {
503    pub id: LanguageModelToolUseId,
504    /// The ID of the Assistant message in which the tool use was requested.
505    #[allow(unused)]
506    pub assistant_message_id: MessageId,
507    pub name: Arc<str>,
508    pub ui_text: Arc<str>,
509    pub input: serde_json::Value,
510    pub status: PendingToolUseStatus,
511}
512
513#[derive(Debug, Clone)]
514pub struct Confirmation {
515    pub tool_use_id: LanguageModelToolUseId,
516    pub input: serde_json::Value,
517    pub ui_text: Arc<str>,
518    pub request: Arc<LanguageModelRequest>,
519    pub tool: Arc<dyn Tool>,
520}
521
522#[derive(Debug, Clone)]
523pub enum PendingToolUseStatus {
524    InputStillStreaming,
525    Idle,
526    NeedsConfirmation(Arc<Confirmation>),
527    Running { _task: Shared<Task<()>> },
528    Error(#[allow(unused)] Arc<str>),
529}
530
531impl PendingToolUseStatus {
532    pub fn is_idle(&self) -> bool {
533        matches!(self, PendingToolUseStatus::Idle)
534    }
535
536    pub fn is_error(&self) -> bool {
537        matches!(self, PendingToolUseStatus::Error(_))
538    }
539
540    pub fn needs_confirmation(&self) -> bool {
541        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
542    }
543}
544
545#[derive(Clone)]
546pub struct ToolUseMetadata {
547    pub model: Arc<dyn LanguageModel>,
548    pub thread_id: ThreadId,
549    pub prompt_id: PromptId,
550}