tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
/// A snapshot of one tool use together with its derived display state, as
/// built by [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use, cloned from the stored [`LanguageModelToolUse`].
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label for this tool use, produced by [`ToolUseState::tool_ui_label`].
    pub ui_text: SharedString,
    /// Status derived from the recorded result if there is one, otherwise from
    /// the pending entry for this tool use.
    pub status: ToolUseStatus,
    /// Input that was supplied for the tool.
    pub input: serde_json::Value,
    /// Icon reported by the tool, or [`IconName::Cog`] when the tool is not
    /// found in the working set.
    pub icon: ui::IconName,
    /// Whether the tool reports that this input needs confirmation; `false`
    /// when the tool is not found in the working set.
    pub needs_confirmation: bool,
}
 29
/// Tracks the tool uses requested by assistant messages, their pending state,
/// and the results recorded for them.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, keyed by message id.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Recorded results (successful, failed, or canceled), keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have been requested and not yet removed by a successful
    /// result or a cancellation. A failed tool use stays here with an error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered through `insert_tool_result_card`, keyed by tool use id.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete; it is removed and
    /// reported to telemetry when the tool's output is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 38
 39impl ToolUseState {
 40    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 41        Self {
 42            tools,
 43            tool_uses_by_assistant_message: HashMap::default(),
 44            tool_results: HashMap::default(),
 45            pending_tool_uses_by_id: HashMap::default(),
 46            tool_result_cards: HashMap::default(),
 47            tool_use_metadata_by_id: HashMap::default(),
 48        }
 49    }
 50
 51    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 52    ///
 53    /// Accepts a function to filter the tools that should be used to populate the state.
 54    pub fn from_serialized_messages(
 55        tools: Entity<ToolWorkingSet>,
 56        messages: &[SerializedMessage],
 57    ) -> Self {
 58        let mut this = Self::new(tools);
 59        let mut tool_names_by_id = HashMap::default();
 60
 61        for message in messages {
 62            match message.role {
 63                Role::Assistant => {
 64                    if !message.tool_uses.is_empty() {
 65                        let tool_uses = message
 66                            .tool_uses
 67                            .iter()
 68                            .map(|tool_use| LanguageModelToolUse {
 69                                id: tool_use.id.clone(),
 70                                name: tool_use.name.clone().into(),
 71                                raw_input: tool_use.input.to_string(),
 72                                input: tool_use.input.clone(),
 73                                is_input_complete: true,
 74                            })
 75                            .collect::<Vec<_>>();
 76
 77                        tool_names_by_id.extend(
 78                            tool_uses
 79                                .iter()
 80                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 81                        );
 82
 83                        this.tool_uses_by_assistant_message
 84                            .insert(message.id, tool_uses);
 85
 86                        for tool_result in &message.tool_results {
 87                            let tool_use_id = tool_result.tool_use_id.clone();
 88                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 89                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 90                                continue;
 91                            };
 92
 93                            this.tool_results.insert(
 94                                tool_use_id.clone(),
 95                                LanguageModelToolResult {
 96                                    tool_use_id,
 97                                    tool_name: tool_use.clone(),
 98                                    is_error: tool_result.is_error,
 99                                    content: tool_result.content.clone(),
100                                },
101                            );
102                        }
103                    }
104                }
105                Role::System | Role::User => {}
106            }
107        }
108
109        this
110    }
111
112    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
113        let mut pending_tools = Vec::new();
114        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
115            self.tool_results.insert(
116                tool_use_id.clone(),
117                LanguageModelToolResult {
118                    tool_use_id,
119                    tool_name: tool_use.name.clone(),
120                    content: "Tool canceled by user".into(),
121                    is_error: true,
122                },
123            );
124            pending_tools.push(tool_use.clone());
125        }
126        pending_tools
127    }
128
129    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
130        self.pending_tool_uses_by_id.values().collect()
131    }
132
133    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
134        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
135            return Vec::new();
136        };
137
138        let mut tool_uses = Vec::new();
139
140        for tool_use in tool_uses_for_message.iter() {
141            let tool_result = self.tool_results.get(&tool_use.id);
142
143            let status = (|| {
144                if let Some(tool_result) = tool_result {
145                    return if tool_result.is_error {
146                        ToolUseStatus::Error(tool_result.content.clone().into())
147                    } else {
148                        ToolUseStatus::Finished(tool_result.content.clone().into())
149                    };
150                }
151
152                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
153                    match pending_tool_use.status {
154                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
155                        PendingToolUseStatus::NeedsConfirmation { .. } => {
156                            ToolUseStatus::NeedsConfirmation
157                        }
158                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
159                        PendingToolUseStatus::Error(ref err) => {
160                            ToolUseStatus::Error(err.clone().into())
161                        }
162                        PendingToolUseStatus::InputStillStreaming => {
163                            ToolUseStatus::InputStillStreaming
164                        }
165                    }
166                } else {
167                    ToolUseStatus::Pending
168                }
169            })();
170
171            let (icon, needs_confirmation) =
172                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
173                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
174                } else {
175                    (IconName::Cog, false)
176                };
177
178            tool_uses.push(ToolUse {
179                id: tool_use.id.clone(),
180                name: tool_use.name.clone().into(),
181                ui_text: self.tool_ui_label(
182                    &tool_use.name,
183                    &tool_use.input,
184                    tool_use.is_input_complete,
185                    cx,
186                ),
187                input: tool_use.input.clone(),
188                status,
189                icon,
190                needs_confirmation,
191            })
192        }
193
194        tool_uses
195    }
196
197    pub fn tool_ui_label(
198        &self,
199        tool_name: &str,
200        input: &serde_json::Value,
201        is_input_complete: bool,
202        cx: &App,
203    ) -> SharedString {
204        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
205            if is_input_complete {
206                tool.ui_text(input).into()
207            } else {
208                tool.still_streaming_ui_text(input).into()
209            }
210        } else {
211            format!("Unknown tool {tool_name:?}").into()
212        }
213    }
214
215    pub fn tool_results_for_message(
216        &self,
217        assistant_message_id: MessageId,
218    ) -> Vec<&LanguageModelToolResult> {
219        let Some(tool_uses) = self
220            .tool_uses_by_assistant_message
221            .get(&assistant_message_id)
222        else {
223            return Vec::new();
224        };
225
226        tool_uses
227            .iter()
228            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
229            .collect()
230    }
231
232    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
233        self.tool_uses_by_assistant_message
234            .get(&assistant_message_id)
235            .map_or(false, |results| !results.is_empty())
236    }
237
238    pub fn tool_result(
239        &self,
240        tool_use_id: &LanguageModelToolUseId,
241    ) -> Option<&LanguageModelToolResult> {
242        self.tool_results.get(tool_use_id)
243    }
244
245    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
246        self.tool_result_cards.get(tool_use_id)
247    }
248
249    pub fn insert_tool_result_card(
250        &mut self,
251        tool_use_id: LanguageModelToolUseId,
252        card: AnyToolCard,
253    ) {
254        self.tool_result_cards.insert(tool_use_id, card);
255    }
256
257    pub fn request_tool_use(
258        &mut self,
259        assistant_message_id: MessageId,
260        tool_use: LanguageModelToolUse,
261        metadata: ToolUseMetadata,
262        cx: &App,
263    ) -> Arc<str> {
264        let tool_uses = self
265            .tool_uses_by_assistant_message
266            .entry(assistant_message_id)
267            .or_default();
268
269        let mut existing_tool_use_found = false;
270
271        for existing_tool_use in tool_uses.iter_mut() {
272            if existing_tool_use.id == tool_use.id {
273                *existing_tool_use = tool_use.clone();
274                existing_tool_use_found = true;
275            }
276        }
277
278        if !existing_tool_use_found {
279            tool_uses.push(tool_use.clone());
280        }
281
282        let status = if tool_use.is_input_complete {
283            self.tool_use_metadata_by_id
284                .insert(tool_use.id.clone(), metadata);
285
286            PendingToolUseStatus::Idle
287        } else {
288            PendingToolUseStatus::InputStillStreaming
289        };
290
291        let ui_text: Arc<str> = self
292            .tool_ui_label(
293                &tool_use.name,
294                &tool_use.input,
295                tool_use.is_input_complete,
296                cx,
297            )
298            .into();
299
300        self.pending_tool_uses_by_id.insert(
301            tool_use.id.clone(),
302            PendingToolUse {
303                assistant_message_id,
304                id: tool_use.id,
305                name: tool_use.name.clone(),
306                ui_text: ui_text.clone(),
307                input: tool_use.input,
308                status,
309            },
310        );
311
312        ui_text
313    }
314
315    pub fn run_pending_tool(
316        &mut self,
317        tool_use_id: LanguageModelToolUseId,
318        ui_text: SharedString,
319        task: Task<()>,
320    ) {
321        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
322            tool_use.ui_text = ui_text.into();
323            tool_use.status = PendingToolUseStatus::Running {
324                _task: task.shared(),
325            };
326        }
327    }
328
329    pub fn confirm_tool_use(
330        &mut self,
331        tool_use_id: LanguageModelToolUseId,
332        ui_text: impl Into<Arc<str>>,
333        input: serde_json::Value,
334        messages: Arc<Vec<LanguageModelRequestMessage>>,
335        tool: Arc<dyn Tool>,
336    ) {
337        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
338            let ui_text = ui_text.into();
339            tool_use.ui_text = ui_text.clone();
340            let confirmation = Confirmation {
341                tool_use_id,
342                input,
343                messages,
344                tool,
345                ui_text,
346            };
347            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
348        }
349    }
350
351    pub fn insert_tool_output(
352        &mut self,
353        tool_use_id: LanguageModelToolUseId,
354        tool_name: Arc<str>,
355        output: Result<String>,
356        configured_model: Option<&ConfiguredModel>,
357    ) -> Option<PendingToolUse> {
358        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
359
360        telemetry::event!(
361            "Agent Tool Finished",
362            model = metadata
363                .as_ref()
364                .map(|metadata| metadata.model.telemetry_id()),
365            model_provider = metadata
366                .as_ref()
367                .map(|metadata| metadata.model.provider_id().to_string()),
368            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
369            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
370            tool_name,
371            success = output.is_ok()
372        );
373
374        match output {
375            Ok(tool_result) => {
376                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
377
378                // Protect from clearly large output
379                let tool_output_limit = configured_model
380                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
381                    .unwrap_or(usize::MAX);
382
383                let tool_result = if tool_result.len() <= tool_output_limit {
384                    tool_result
385                } else {
386                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
387
388                    format!(
389                        "Tool result too long. The first {} bytes:\n\n{}",
390                        truncated.len(),
391                        truncated
392                    )
393                };
394
395                self.tool_results.insert(
396                    tool_use_id.clone(),
397                    LanguageModelToolResult {
398                        tool_use_id: tool_use_id.clone(),
399                        tool_name,
400                        content: tool_result.into(),
401                        is_error: false,
402                    },
403                );
404                self.pending_tool_uses_by_id.remove(&tool_use_id)
405            }
406            Err(err) => {
407                self.tool_results.insert(
408                    tool_use_id.clone(),
409                    LanguageModelToolResult {
410                        tool_use_id: tool_use_id.clone(),
411                        tool_name,
412                        content: err.to_string().into(),
413                        is_error: true,
414                    },
415                );
416
417                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
418                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
419                }
420
421                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
422            }
423        }
424    }
425
426    pub fn attach_tool_uses(
427        &self,
428        message_id: MessageId,
429        request_message: &mut LanguageModelRequestMessage,
430    ) {
431        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
432            for tool_use in tool_uses {
433                if self.tool_results.contains_key(&tool_use.id) {
434                    // Do not send tool uses until they are completed
435                    request_message
436                        .content
437                        .push(MessageContent::ToolUse(tool_use.clone()));
438                } else {
439                    log::debug!(
440                        "skipped tool use {:?} because it is still pending",
441                        tool_use
442                    );
443                }
444            }
445        }
446    }
447
448    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
449        self.tool_uses_by_assistant_message
450            .contains_key(&assistant_message_id)
451    }
452
453    pub fn tool_results_message(
454        &self,
455        assistant_message_id: MessageId,
456    ) -> Option<LanguageModelRequestMessage> {
457        let tool_uses = self
458            .tool_uses_by_assistant_message
459            .get(&assistant_message_id)?;
460
461        if tool_uses.is_empty() {
462            return None;
463        }
464
465        let mut request_message = LanguageModelRequestMessage {
466            role: Role::User,
467            content: vec![],
468            cache: false,
469        };
470
471        for tool_use in tool_uses {
472            if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
473                request_message
474                    .content
475                    .push(MessageContent::ToolResult(LanguageModelToolResult {
476                        tool_use_id: tool_use.id.clone(),
477                        tool_name: tool_result.tool_name.clone(),
478                        is_error: tool_result.is_error,
479                        content: if tool_result.content.is_empty() {
480                            // Surprisingly, the API fails if we return an empty string here.
481                            // It thinks we are sending a tool use without a tool result.
482                            "<Tool returned an empty string>".into()
483                        } else {
484                            tool_result.content.clone()
485                        },
486                    }));
487            }
488        }
489
490        Some(request_message)
491    }
492}
493
/// A tool use that has been requested and is tracked in the pending map of
/// [`ToolUseState`].
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; updated when the tool use starts running or
    /// needs confirmation.
    pub ui_text: Arc<str>,
    /// Input supplied for the tool.
    pub input: serde_json::Value,
    /// Current stage of the pending tool use.
    pub status: PendingToolUseStatus,
}
505
/// The data captured by [`ToolUseState::confirm_tool_use`] for a tool use
/// that is waiting for confirmation.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Input to pass to the tool.
    pub input: serde_json::Value,
    /// Label for the tool use.
    pub ui_text: Arc<str>,
    /// Request messages supplied alongside the confirmation request.
    // NOTE(review): presumably the conversation passed to the tool when it
    // runs — confirm against the caller.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
514
/// The stage a [`PendingToolUse`] is in.
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The tool use was requested while its input was not yet complete.
    InputStillStreaming,
    /// The input is complete and the tool use has not progressed further.
    Idle,
    /// The tool use is waiting for confirmation; carries what was captured
    /// when confirmation was requested.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running; holds a shared handle to its task.
    // NOTE(review): the handle is presumably kept so the task stays alive
    // while the tool use is pending — confirm.
    Running { _task: Shared<Task<()>> },
    /// The tool failed; carries the error text.
    Error(#[allow(unused)] Arc<str>),
}
523
524impl PendingToolUseStatus {
525    pub fn is_idle(&self) -> bool {
526        matches!(self, PendingToolUseStatus::Idle)
527    }
528
529    pub fn is_error(&self) -> bool {
530        matches!(self, PendingToolUseStatus::Error(_))
531    }
532
533    pub fn needs_confirmation(&self) -> bool {
534        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
535    }
536}
537
/// Context stored for a tool use and reported with the "Agent Tool Finished"
/// telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry id and provider id are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread id reported with the event.
    pub thread_id: ThreadId,
    /// Prompt id reported with the event.
    pub prompt_id: PromptId,
}