tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
 19#[derive(Debug)]
 20pub struct ToolUse {
 21    pub id: LanguageModelToolUseId,
 22    pub name: SharedString,
 23    pub ui_text: SharedString,
 24    pub status: ToolUseStatus,
 25    pub input: serde_json::Value,
 26    pub icon: ui::IconName,
 27    pub needs_confirmation: bool,
 28}
 29
 30pub struct ToolUseState {
 31    tools: Entity<ToolWorkingSet>,
 32    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 33    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 34    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 35    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 36    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
 37    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
 38}
 39
 40impl ToolUseState {
 41    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 42        Self {
 43            tools,
 44            tool_uses_by_assistant_message: HashMap::default(),
 45            tool_uses_by_user_message: HashMap::default(),
 46            tool_results: HashMap::default(),
 47            pending_tool_uses_by_id: HashMap::default(),
 48            tool_result_cards: HashMap::default(),
 49            tool_use_metadata_by_id: HashMap::default(),
 50        }
 51    }
 52
 53    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 54    ///
 55    /// Accepts a function to filter the tools that should be used to populate the state.
 56    pub fn from_serialized_messages(
 57        tools: Entity<ToolWorkingSet>,
 58        messages: &[SerializedMessage],
 59        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 60    ) -> Self {
 61        let mut this = Self::new(tools);
 62        let mut tool_names_by_id = HashMap::default();
 63
 64        for message in messages {
 65            match message.role {
 66                Role::Assistant => {
 67                    if !message.tool_uses.is_empty() {
 68                        let tool_uses = message
 69                            .tool_uses
 70                            .iter()
 71                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 72                            .map(|tool_use| LanguageModelToolUse {
 73                                id: tool_use.id.clone(),
 74                                name: tool_use.name.clone().into(),
 75                                input: tool_use.input.clone(),
 76                                is_input_complete: true,
 77                            })
 78                            .collect::<Vec<_>>();
 79
 80                        tool_names_by_id.extend(
 81                            tool_uses
 82                                .iter()
 83                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 84                        );
 85
 86                        this.tool_uses_by_assistant_message
 87                            .insert(message.id, tool_uses);
 88                    }
 89                }
 90                Role::User => {
 91                    if !message.tool_results.is_empty() {
 92                        let tool_uses_by_user_message = this
 93                            .tool_uses_by_user_message
 94                            .entry(message.id)
 95                            .or_default();
 96
 97                        for tool_result in &message.tool_results {
 98                            let tool_use_id = tool_result.tool_use_id.clone();
 99                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
101                                continue;
102                            };
103
104                            if !(filter_by_tool_name)(tool_use.as_ref()) {
105                                continue;
106                            }
107
108                            tool_uses_by_user_message.push(tool_use_id.clone());
109                            this.tool_results.insert(
110                                tool_use_id.clone(),
111                                LanguageModelToolResult {
112                                    tool_use_id,
113                                    tool_name: tool_use.clone(),
114                                    is_error: tool_result.is_error,
115                                    content: tool_result.content.clone(),
116                                },
117                            );
118                        }
119                    }
120                }
121                Role::System => {}
122            }
123        }
124
125        this
126    }
127
128    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
129        let mut pending_tools = Vec::new();
130        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
131            self.tool_results.insert(
132                tool_use_id.clone(),
133                LanguageModelToolResult {
134                    tool_use_id,
135                    tool_name: tool_use.name.clone(),
136                    content: "Tool canceled by user".into(),
137                    is_error: true,
138                },
139            );
140            pending_tools.push(tool_use.clone());
141        }
142        pending_tools
143    }
144
145    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
146        self.pending_tool_uses_by_id.values().collect()
147    }
148
149    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
150        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
151            return Vec::new();
152        };
153
154        let mut tool_uses = Vec::new();
155
156        for tool_use in tool_uses_for_message.iter() {
157            let tool_result = self.tool_results.get(&tool_use.id);
158
159            let status = (|| {
160                if let Some(tool_result) = tool_result {
161                    return if tool_result.is_error {
162                        ToolUseStatus::Error(tool_result.content.clone().into())
163                    } else {
164                        ToolUseStatus::Finished(tool_result.content.clone().into())
165                    };
166                }
167
168                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
169                    match pending_tool_use.status {
170                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
171                        PendingToolUseStatus::NeedsConfirmation { .. } => {
172                            ToolUseStatus::NeedsConfirmation
173                        }
174                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
175                        PendingToolUseStatus::Error(ref err) => {
176                            ToolUseStatus::Error(err.clone().into())
177                        }
178                        PendingToolUseStatus::InputStillStreaming => {
179                            ToolUseStatus::InputStillStreaming
180                        }
181                    }
182                } else {
183                    ToolUseStatus::Pending
184                }
185            })();
186
187            let (icon, needs_confirmation) =
188                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
189                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
190                } else {
191                    (IconName::Cog, false)
192                };
193
194            tool_uses.push(ToolUse {
195                id: tool_use.id.clone(),
196                name: tool_use.name.clone().into(),
197                ui_text: self.tool_ui_label(
198                    &tool_use.name,
199                    &tool_use.input,
200                    tool_use.is_input_complete,
201                    cx,
202                ),
203                input: tool_use.input.clone(),
204                status,
205                icon,
206                needs_confirmation,
207            })
208        }
209
210        tool_uses
211    }
212
213    pub fn tool_ui_label(
214        &self,
215        tool_name: &str,
216        input: &serde_json::Value,
217        is_input_complete: bool,
218        cx: &App,
219    ) -> SharedString {
220        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
221            if is_input_complete {
222                tool.ui_text(input).into()
223            } else {
224                tool.still_streaming_ui_text(input).into()
225            }
226        } else {
227            format!("Unknown tool {tool_name:?}").into()
228        }
229    }
230
231    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
232        let empty = Vec::new();
233
234        self.tool_uses_by_user_message
235            .get(&message_id)
236            .unwrap_or(&empty)
237            .iter()
238            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
239            .collect()
240    }
241
242    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
243        self.tool_uses_by_user_message
244            .get(&message_id)
245            .map_or(false, |results| !results.is_empty())
246    }
247
248    pub fn tool_result(
249        &self,
250        tool_use_id: &LanguageModelToolUseId,
251    ) -> Option<&LanguageModelToolResult> {
252        self.tool_results.get(tool_use_id)
253    }
254
255    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
256        self.tool_result_cards.get(tool_use_id)
257    }
258
259    pub fn insert_tool_result_card(
260        &mut self,
261        tool_use_id: LanguageModelToolUseId,
262        card: AnyToolCard,
263    ) {
264        self.tool_result_cards.insert(tool_use_id, card);
265    }
266
267    pub fn request_tool_use(
268        &mut self,
269        assistant_message_id: MessageId,
270        tool_use: LanguageModelToolUse,
271        metadata: ToolUseMetadata,
272        cx: &App,
273    ) -> Arc<str> {
274        let tool_uses = self
275            .tool_uses_by_assistant_message
276            .entry(assistant_message_id)
277            .or_default();
278
279        let mut existing_tool_use_found = false;
280
281        for existing_tool_use in tool_uses.iter_mut() {
282            if existing_tool_use.id == tool_use.id {
283                *existing_tool_use = tool_use.clone();
284                existing_tool_use_found = true;
285            }
286        }
287
288        if !existing_tool_use_found {
289            tool_uses.push(tool_use.clone());
290        }
291
292        let status = if tool_use.is_input_complete {
293            self.tool_use_metadata_by_id
294                .insert(tool_use.id.clone(), metadata);
295
296            // The tool use is being requested by the Assistant, so we want to
297            // attach the tool results to the next user message.
298            let next_user_message_id = MessageId(assistant_message_id.0 + 1);
299            self.tool_uses_by_user_message
300                .entry(next_user_message_id)
301                .or_default()
302                .push(tool_use.id.clone());
303
304            PendingToolUseStatus::Idle
305        } else {
306            PendingToolUseStatus::InputStillStreaming
307        };
308
309        let ui_text: Arc<str> = self
310            .tool_ui_label(
311                &tool_use.name,
312                &tool_use.input,
313                tool_use.is_input_complete,
314                cx,
315            )
316            .into();
317
318        self.pending_tool_uses_by_id.insert(
319            tool_use.id.clone(),
320            PendingToolUse {
321                assistant_message_id,
322                id: tool_use.id,
323                name: tool_use.name.clone(),
324                ui_text: ui_text.clone(),
325                input: tool_use.input,
326                status,
327            },
328        );
329
330        ui_text
331    }
332
333    pub fn run_pending_tool(
334        &mut self,
335        tool_use_id: LanguageModelToolUseId,
336        ui_text: SharedString,
337        task: Task<()>,
338    ) {
339        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
340            tool_use.ui_text = ui_text.into();
341            tool_use.status = PendingToolUseStatus::Running {
342                _task: task.shared(),
343            };
344        }
345    }
346
347    pub fn confirm_tool_use(
348        &mut self,
349        tool_use_id: LanguageModelToolUseId,
350        ui_text: impl Into<Arc<str>>,
351        input: serde_json::Value,
352        messages: Arc<Vec<LanguageModelRequestMessage>>,
353        tool: Arc<dyn Tool>,
354    ) {
355        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
356            let ui_text = ui_text.into();
357            tool_use.ui_text = ui_text.clone();
358            let confirmation = Confirmation {
359                tool_use_id,
360                input,
361                messages,
362                tool,
363                ui_text,
364            };
365            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
366        }
367    }
368
369    pub fn insert_tool_output(
370        &mut self,
371        tool_use_id: LanguageModelToolUseId,
372        tool_name: Arc<str>,
373        output: Result<String>,
374        cx: &App,
375    ) -> Option<PendingToolUse> {
376        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
377
378        telemetry::event!(
379            "Agent Tool Finished",
380            model = metadata
381                .as_ref()
382                .map(|metadata| metadata.model.telemetry_id()),
383            model_provider = metadata
384                .as_ref()
385                .map(|metadata| metadata.model.provider_id().to_string()),
386            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
387            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
388            tool_name,
389            success = output.is_ok()
390        );
391
392        match output {
393            Ok(tool_result) => {
394                let model_registry = LanguageModelRegistry::read_global(cx);
395
396                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
397
398                // Protect from clearly large output
399                let tool_output_limit = model_registry
400                    .default_model()
401                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
402                    .unwrap_or(usize::MAX);
403
404                let tool_result = if tool_result.len() <= tool_output_limit {
405                    tool_result
406                } else {
407                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
408
409                    format!(
410                        "Tool result too long. The first {} bytes:\n\n{}",
411                        truncated.len(),
412                        truncated
413                    )
414                };
415
416                self.tool_results.insert(
417                    tool_use_id.clone(),
418                    LanguageModelToolResult {
419                        tool_use_id: tool_use_id.clone(),
420                        tool_name,
421                        content: tool_result.into(),
422                        is_error: false,
423                    },
424                );
425                self.pending_tool_uses_by_id.remove(&tool_use_id)
426            }
427            Err(err) => {
428                self.tool_results.insert(
429                    tool_use_id.clone(),
430                    LanguageModelToolResult {
431                        tool_use_id: tool_use_id.clone(),
432                        tool_name,
433                        content: err.to_string().into(),
434                        is_error: true,
435                    },
436                );
437
438                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
439                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
440                }
441
442                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
443            }
444        }
445    }
446
447    pub fn attach_tool_uses(
448        &self,
449        message_id: MessageId,
450        request_message: &mut LanguageModelRequestMessage,
451    ) {
452        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
453            for tool_use in tool_uses {
454                if self.tool_results.contains_key(&tool_use.id) {
455                    // Do not send tool uses until they are completed
456                    request_message
457                        .content
458                        .push(MessageContent::ToolUse(tool_use.clone()));
459                } else {
460                    log::debug!(
461                        "skipped tool use {:?} because it is still pending",
462                        tool_use
463                    );
464                }
465            }
466        }
467    }
468
469    pub fn attach_tool_results(
470        &self,
471        message_id: MessageId,
472        request_message: &mut LanguageModelRequestMessage,
473    ) {
474        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
475            for tool_use_id in tool_uses {
476                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
477                    request_message.content.push(MessageContent::ToolResult(
478                        LanguageModelToolResult {
479                            tool_use_id: tool_use_id.clone(),
480                            tool_name: tool_result.tool_name.clone(),
481                            is_error: tool_result.is_error,
482                            content: if tool_result.content.is_empty() {
483                                // Surprisingly, the API fails if we return an empty string here.
484                                // It thinks we are sending a tool use without a tool result.
485                                "<Tool returned an empty string>".into()
486                            } else {
487                                tool_result.content.clone()
488                            },
489                        },
490                    ));
491                }
492            }
493        }
494    }
495}
496
497#[derive(Debug, Clone)]
498pub struct PendingToolUse {
499    pub id: LanguageModelToolUseId,
500    /// The ID of the Assistant message in which the tool use was requested.
501    #[allow(unused)]
502    pub assistant_message_id: MessageId,
503    pub name: Arc<str>,
504    pub ui_text: Arc<str>,
505    pub input: serde_json::Value,
506    pub status: PendingToolUseStatus,
507}
508
509#[derive(Debug, Clone)]
510pub struct Confirmation {
511    pub tool_use_id: LanguageModelToolUseId,
512    pub input: serde_json::Value,
513    pub ui_text: Arc<str>,
514    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
515    pub tool: Arc<dyn Tool>,
516}
517
518#[derive(Debug, Clone)]
519pub enum PendingToolUseStatus {
520    InputStillStreaming,
521    Idle,
522    NeedsConfirmation(Arc<Confirmation>),
523    Running { _task: Shared<Task<()>> },
524    Error(#[allow(unused)] Arc<str>),
525}
526
527impl PendingToolUseStatus {
528    pub fn is_idle(&self) -> bool {
529        matches!(self, PendingToolUseStatus::Idle)
530    }
531
532    pub fn is_error(&self) -> bool {
533        matches!(self, PendingToolUseStatus::Error(_))
534    }
535
536    pub fn needs_confirmation(&self) -> bool {
537        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
538    }
539}
540
541#[derive(Clone)]
542pub struct ToolUseMetadata {
543    pub model: Arc<dyn LanguageModel>,
544    pub thread_id: ThreadId,
545    pub prompt_id: PromptId,
546}