tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
/// A snapshot of one tool use attached to an assistant message, as assembled
/// by [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label for this tool use, produced by [`ToolUseState::tool_ui_label`].
    pub ui_text: SharedString,
    /// Status derived from the stored tool result or, failing that, the pending entry.
    pub status: ToolUseStatus,
    /// JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Icon reported by the tool; `IconName::Cog` when the tool is not in the working set.
    pub icon: ui::IconName,
    /// Whether the tool reports that this input needs confirmation; `false` when the
    /// tool is not in the working set.
    pub needs_confirmation: bool,
}
 29
/// Bookkeeping for the tool uses requested in a thread and their results.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested in each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded result for each tool use that has one (including errors and cancellations).
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Requested tool uses; an entry is removed when a successful output is inserted
    /// or when pending tool uses are canceled.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered through `insert_tool_result_card`.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete; consumed for telemetry
    /// when the tool output is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 39
 40impl ToolUseState {
 41    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 42        Self {
 43            tools,
 44            tool_uses_by_assistant_message: HashMap::default(),
 45            tool_uses_by_user_message: HashMap::default(),
 46            tool_results: HashMap::default(),
 47            pending_tool_uses_by_id: HashMap::default(),
 48            tool_result_cards: HashMap::default(),
 49            tool_use_metadata_by_id: HashMap::default(),
 50        }
 51    }
 52
 53    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 54    ///
 55    /// Accepts a function to filter the tools that should be used to populate the state.
 56    pub fn from_serialized_messages(
 57        tools: Entity<ToolWorkingSet>,
 58        messages: &[SerializedMessage],
 59        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 60    ) -> Self {
 61        let mut this = Self::new(tools);
 62        let mut tool_names_by_id = HashMap::default();
 63
 64        for message in messages {
 65            match message.role {
 66                Role::Assistant => {
 67                    if !message.tool_uses.is_empty() {
 68                        let tool_uses = message
 69                            .tool_uses
 70                            .iter()
 71                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 72                            .map(|tool_use| LanguageModelToolUse {
 73                                id: tool_use.id.clone(),
 74                                name: tool_use.name.clone().into(),
 75                                raw_input: tool_use.input.to_string(),
 76                                input: tool_use.input.clone(),
 77                                is_input_complete: true,
 78                            })
 79                            .collect::<Vec<_>>();
 80
 81                        tool_names_by_id.extend(
 82                            tool_uses
 83                                .iter()
 84                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 85                        );
 86
 87                        this.tool_uses_by_assistant_message
 88                            .insert(message.id, tool_uses);
 89                    }
 90                }
 91                Role::User => {
 92                    if !message.tool_results.is_empty() {
 93                        let tool_uses_by_user_message = this
 94                            .tool_uses_by_user_message
 95                            .entry(message.id)
 96                            .or_default();
 97
 98                        for tool_result in &message.tool_results {
 99                            let tool_use_id = tool_result.tool_use_id.clone();
100                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
101                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
102                                continue;
103                            };
104
105                            if !(filter_by_tool_name)(tool_use.as_ref()) {
106                                continue;
107                            }
108
109                            tool_uses_by_user_message.push(tool_use_id.clone());
110                            this.tool_results.insert(
111                                tool_use_id.clone(),
112                                LanguageModelToolResult {
113                                    tool_use_id,
114                                    tool_name: tool_use.clone(),
115                                    is_error: tool_result.is_error,
116                                    content: tool_result.content.clone(),
117                                },
118                            );
119                        }
120                    }
121                }
122                Role::System => {}
123            }
124        }
125
126        this
127    }
128
129    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130        let mut pending_tools = Vec::new();
131        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
132            self.tool_results.insert(
133                tool_use_id.clone(),
134                LanguageModelToolResult {
135                    tool_use_id,
136                    tool_name: tool_use.name.clone(),
137                    content: "Tool canceled by user".into(),
138                    is_error: true,
139                },
140            );
141            pending_tools.push(tool_use.clone());
142        }
143        pending_tools
144    }
145
146    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
147        self.pending_tool_uses_by_id.values().collect()
148    }
149
150    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
151        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
152            return Vec::new();
153        };
154
155        let mut tool_uses = Vec::new();
156
157        for tool_use in tool_uses_for_message.iter() {
158            let tool_result = self.tool_results.get(&tool_use.id);
159
160            let status = (|| {
161                if let Some(tool_result) = tool_result {
162                    return if tool_result.is_error {
163                        ToolUseStatus::Error(tool_result.content.clone().into())
164                    } else {
165                        ToolUseStatus::Finished(tool_result.content.clone().into())
166                    };
167                }
168
169                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
170                    match pending_tool_use.status {
171                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
172                        PendingToolUseStatus::NeedsConfirmation { .. } => {
173                            ToolUseStatus::NeedsConfirmation
174                        }
175                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
176                        PendingToolUseStatus::Error(ref err) => {
177                            ToolUseStatus::Error(err.clone().into())
178                        }
179                        PendingToolUseStatus::InputStillStreaming => {
180                            ToolUseStatus::InputStillStreaming
181                        }
182                    }
183                } else {
184                    ToolUseStatus::Pending
185                }
186            })();
187
188            let (icon, needs_confirmation) =
189                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
190                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
191                } else {
192                    (IconName::Cog, false)
193                };
194
195            tool_uses.push(ToolUse {
196                id: tool_use.id.clone(),
197                name: tool_use.name.clone().into(),
198                ui_text: self.tool_ui_label(
199                    &tool_use.name,
200                    &tool_use.input,
201                    tool_use.is_input_complete,
202                    cx,
203                ),
204                input: tool_use.input.clone(),
205                status,
206                icon,
207                needs_confirmation,
208            })
209        }
210
211        tool_uses
212    }
213
214    pub fn tool_ui_label(
215        &self,
216        tool_name: &str,
217        input: &serde_json::Value,
218        is_input_complete: bool,
219        cx: &App,
220    ) -> SharedString {
221        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
222            if is_input_complete {
223                tool.ui_text(input).into()
224            } else {
225                tool.still_streaming_ui_text(input).into()
226            }
227        } else {
228            format!("Unknown tool {tool_name:?}").into()
229        }
230    }
231
232    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
233        let empty = Vec::new();
234
235        self.tool_uses_by_user_message
236            .get(&message_id)
237            .unwrap_or(&empty)
238            .iter()
239            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
240            .collect()
241    }
242
243    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
244        self.tool_uses_by_user_message
245            .get(&message_id)
246            .map_or(false, |results| !results.is_empty())
247    }
248
249    pub fn tool_result(
250        &self,
251        tool_use_id: &LanguageModelToolUseId,
252    ) -> Option<&LanguageModelToolResult> {
253        self.tool_results.get(tool_use_id)
254    }
255
256    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
257        self.tool_result_cards.get(tool_use_id)
258    }
259
260    pub fn insert_tool_result_card(
261        &mut self,
262        tool_use_id: LanguageModelToolUseId,
263        card: AnyToolCard,
264    ) {
265        self.tool_result_cards.insert(tool_use_id, card);
266    }
267
268    pub fn request_tool_use(
269        &mut self,
270        assistant_message_id: MessageId,
271        tool_use: LanguageModelToolUse,
272        metadata: ToolUseMetadata,
273        cx: &App,
274    ) -> Arc<str> {
275        let tool_uses = self
276            .tool_uses_by_assistant_message
277            .entry(assistant_message_id)
278            .or_default();
279
280        let mut existing_tool_use_found = false;
281
282        for existing_tool_use in tool_uses.iter_mut() {
283            if existing_tool_use.id == tool_use.id {
284                *existing_tool_use = tool_use.clone();
285                existing_tool_use_found = true;
286            }
287        }
288
289        if !existing_tool_use_found {
290            tool_uses.push(tool_use.clone());
291        }
292
293        let status = if tool_use.is_input_complete {
294            self.tool_use_metadata_by_id
295                .insert(tool_use.id.clone(), metadata);
296
297            // The tool use is being requested by the Assistant, so we want to
298            // attach the tool results to the next user message.
299            let next_user_message_id = MessageId(assistant_message_id.0 + 1);
300            self.tool_uses_by_user_message
301                .entry(next_user_message_id)
302                .or_default()
303                .push(tool_use.id.clone());
304
305            PendingToolUseStatus::Idle
306        } else {
307            PendingToolUseStatus::InputStillStreaming
308        };
309
310        let ui_text: Arc<str> = self
311            .tool_ui_label(
312                &tool_use.name,
313                &tool_use.input,
314                tool_use.is_input_complete,
315                cx,
316            )
317            .into();
318
319        self.pending_tool_uses_by_id.insert(
320            tool_use.id.clone(),
321            PendingToolUse {
322                assistant_message_id,
323                id: tool_use.id,
324                name: tool_use.name.clone(),
325                ui_text: ui_text.clone(),
326                input: tool_use.input,
327                status,
328            },
329        );
330
331        ui_text
332    }
333
334    pub fn run_pending_tool(
335        &mut self,
336        tool_use_id: LanguageModelToolUseId,
337        ui_text: SharedString,
338        task: Task<()>,
339    ) {
340        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
341            tool_use.ui_text = ui_text.into();
342            tool_use.status = PendingToolUseStatus::Running {
343                _task: task.shared(),
344            };
345        }
346    }
347
348    pub fn confirm_tool_use(
349        &mut self,
350        tool_use_id: LanguageModelToolUseId,
351        ui_text: impl Into<Arc<str>>,
352        input: serde_json::Value,
353        messages: Arc<Vec<LanguageModelRequestMessage>>,
354        tool: Arc<dyn Tool>,
355    ) {
356        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
357            let ui_text = ui_text.into();
358            tool_use.ui_text = ui_text.clone();
359            let confirmation = Confirmation {
360                tool_use_id,
361                input,
362                messages,
363                tool,
364                ui_text,
365            };
366            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
367        }
368    }
369
370    pub fn insert_tool_output(
371        &mut self,
372        tool_use_id: LanguageModelToolUseId,
373        tool_name: Arc<str>,
374        output: Result<String>,
375        cx: &App,
376    ) -> Option<PendingToolUse> {
377        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
378
379        telemetry::event!(
380            "Agent Tool Finished",
381            model = metadata
382                .as_ref()
383                .map(|metadata| metadata.model.telemetry_id()),
384            model_provider = metadata
385                .as_ref()
386                .map(|metadata| metadata.model.provider_id().to_string()),
387            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
388            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
389            tool_name,
390            success = output.is_ok()
391        );
392
393        match output {
394            Ok(tool_result) => {
395                let model_registry = LanguageModelRegistry::read_global(cx);
396
397                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
398
399                // Protect from clearly large output
400                let tool_output_limit = model_registry
401                    .default_model()
402                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
403                    .unwrap_or(usize::MAX);
404
405                let tool_result = if tool_result.len() <= tool_output_limit {
406                    tool_result
407                } else {
408                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
409
410                    format!(
411                        "Tool result too long. The first {} bytes:\n\n{}",
412                        truncated.len(),
413                        truncated
414                    )
415                };
416
417                self.tool_results.insert(
418                    tool_use_id.clone(),
419                    LanguageModelToolResult {
420                        tool_use_id: tool_use_id.clone(),
421                        tool_name,
422                        content: tool_result.into(),
423                        is_error: false,
424                    },
425                );
426                self.pending_tool_uses_by_id.remove(&tool_use_id)
427            }
428            Err(err) => {
429                self.tool_results.insert(
430                    tool_use_id.clone(),
431                    LanguageModelToolResult {
432                        tool_use_id: tool_use_id.clone(),
433                        tool_name,
434                        content: err.to_string().into(),
435                        is_error: true,
436                    },
437                );
438
439                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
440                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
441                }
442
443                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
444            }
445        }
446    }
447
448    pub fn attach_tool_uses(
449        &self,
450        message_id: MessageId,
451        request_message: &mut LanguageModelRequestMessage,
452    ) {
453        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
454            for tool_use in tool_uses {
455                if self.tool_results.contains_key(&tool_use.id) {
456                    // Do not send tool uses until they are completed
457                    request_message
458                        .content
459                        .push(MessageContent::ToolUse(tool_use.clone()));
460                } else {
461                    log::debug!(
462                        "skipped tool use {:?} because it is still pending",
463                        tool_use
464                    );
465                }
466            }
467        }
468    }
469
470    pub fn attach_tool_results(
471        &self,
472        message_id: MessageId,
473        request_message: &mut LanguageModelRequestMessage,
474    ) {
475        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
476            for tool_use_id in tool_uses {
477                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
478                    request_message.content.push(MessageContent::ToolResult(
479                        LanguageModelToolResult {
480                            tool_use_id: tool_use_id.clone(),
481                            tool_name: tool_result.tool_name.clone(),
482                            is_error: tool_result.is_error,
483                            content: if tool_result.content.is_empty() {
484                                // Surprisingly, the API fails if we return an empty string here.
485                                // It thinks we are sending a tool use without a tool result.
486                                "<Tool returned an empty string>".into()
487                            } else {
488                                tool_result.content.clone()
489                            },
490                        },
491                    ));
492                }
493            }
494        }
495    }
496}
497
/// A tool use that has been requested and is tracked in the pending set of
/// [`ToolUseState`].
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; overwritten when the tool use is run or put up for confirmation.
    pub ui_text: Arc<str>,
    /// JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Current stage of the tool use.
    pub status: PendingToolUseStatus,
}
509
/// Data stored with a pending tool use that is waiting for confirmation,
/// built by [`ToolUseState::confirm_tool_use`].
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input passed to `confirm_tool_use`.
    pub input: serde_json::Value,
    /// Label passed to `confirm_tool_use`.
    pub ui_text: Arc<str>,
    /// Request messages passed to `confirm_tool_use`.
    // NOTE(review): presumably the conversation the tool runs against once confirmed — confirm with the caller.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
518
519#[derive(Debug, Clone)]
520pub enum PendingToolUseStatus {
521    InputStillStreaming,
522    Idle,
523    NeedsConfirmation(Arc<Confirmation>),
524    Running { _task: Shared<Task<()>> },
525    Error(#[allow(unused)] Arc<str>),
526}
527
528impl PendingToolUseStatus {
529    pub fn is_idle(&self) -> bool {
530        matches!(self, PendingToolUseStatus::Idle)
531    }
532
533    pub fn is_error(&self) -> bool {
534        matches!(self, PendingToolUseStatus::Error(_))
535    }
536
537    pub fn needs_confirmation(&self) -> bool {
538        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
539    }
540}
541
/// Context stored for a tool use once its input is complete and reported with
/// the telemetry event emitted when the tool output is inserted.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry ID and provider ID are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread ID reported with the event.
    pub thread_id: ThreadId,
    /// Prompt ID reported with the event.
    pub prompt_id: PromptId,
}