tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
 19#[derive(Debug)]
 20pub struct ToolUse {
 21    pub id: LanguageModelToolUseId,
 22    pub name: SharedString,
 23    pub ui_text: SharedString,
 24    pub status: ToolUseStatus,
 25    pub input: serde_json::Value,
 26    pub icon: ui::IconName,
 27    pub needs_confirmation: bool,
 28}
 29
 30pub struct ToolUseState {
 31    tools: Entity<ToolWorkingSet>,
 32    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 33    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 34    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 35    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
 36    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
 37}
 38
 39impl ToolUseState {
 40    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 41        Self {
 42            tools,
 43            tool_uses_by_assistant_message: HashMap::default(),
 44            tool_results: HashMap::default(),
 45            pending_tool_uses_by_id: HashMap::default(),
 46            tool_result_cards: HashMap::default(),
 47            tool_use_metadata_by_id: HashMap::default(),
 48        }
 49    }
 50
 51    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 52    ///
 53    /// Accepts a function to filter the tools that should be used to populate the state.
 54    pub fn from_serialized_messages(
 55        tools: Entity<ToolWorkingSet>,
 56        messages: &[SerializedMessage],
 57    ) -> Self {
 58        let mut this = Self::new(tools);
 59        let mut tool_names_by_id = HashMap::default();
 60
 61        for message in messages {
 62            match message.role {
 63                Role::Assistant => {
 64                    if !message.tool_uses.is_empty() {
 65                        let tool_uses = message
 66                            .tool_uses
 67                            .iter()
 68                            .map(|tool_use| LanguageModelToolUse {
 69                                id: tool_use.id.clone(),
 70                                name: tool_use.name.clone().into(),
 71                                raw_input: tool_use.input.to_string(),
 72                                input: tool_use.input.clone(),
 73                                is_input_complete: true,
 74                            })
 75                            .collect::<Vec<_>>();
 76
 77                        tool_names_by_id.extend(
 78                            tool_uses
 79                                .iter()
 80                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 81                        );
 82
 83                        this.tool_uses_by_assistant_message
 84                            .insert(message.id, tool_uses);
 85
 86                        for tool_result in &message.tool_results {
 87                            let tool_use_id = tool_result.tool_use_id.clone();
 88                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 89                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 90                                continue;
 91                            };
 92
 93                            this.tool_results.insert(
 94                                tool_use_id.clone(),
 95                                LanguageModelToolResult {
 96                                    tool_use_id,
 97                                    tool_name: tool_use.clone(),
 98                                    is_error: tool_result.is_error,
 99                                    content: tool_result.content.clone(),
100                                },
101                            );
102                        }
103                    }
104                }
105                Role::System | Role::User => {}
106            }
107        }
108
109        this
110    }
111
112    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
113        let mut cancelled_tool_uses = Vec::new();
114        self.pending_tool_uses_by_id
115            .retain(|tool_use_id, tool_use| {
116                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
117                    return true;
118                }
119
120                let content = "Tool canceled by user".into();
121                self.tool_results.insert(
122                    tool_use_id.clone(),
123                    LanguageModelToolResult {
124                        tool_use_id: tool_use_id.clone(),
125                        tool_name: tool_use.name.clone(),
126                        content,
127                        is_error: true,
128                    },
129                );
130                cancelled_tool_uses.push(tool_use.clone());
131                false
132            });
133        cancelled_tool_uses
134    }
135
136    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
137        self.pending_tool_uses_by_id.values().collect()
138    }
139
140    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
141        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
142            return Vec::new();
143        };
144
145        let mut tool_uses = Vec::new();
146
147        for tool_use in tool_uses_for_message.iter() {
148            let tool_result = self.tool_results.get(&tool_use.id);
149
150            let status = (|| {
151                if let Some(tool_result) = tool_result {
152                    return if tool_result.is_error {
153                        ToolUseStatus::Error(tool_result.content.clone().into())
154                    } else {
155                        ToolUseStatus::Finished(tool_result.content.clone().into())
156                    };
157                }
158
159                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
160                    match pending_tool_use.status {
161                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
162                        PendingToolUseStatus::NeedsConfirmation { .. } => {
163                            ToolUseStatus::NeedsConfirmation
164                        }
165                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
166                        PendingToolUseStatus::Error(ref err) => {
167                            ToolUseStatus::Error(err.clone().into())
168                        }
169                        PendingToolUseStatus::InputStillStreaming => {
170                            ToolUseStatus::InputStillStreaming
171                        }
172                    }
173                } else {
174                    ToolUseStatus::Pending
175                }
176            })();
177
178            let (icon, needs_confirmation) =
179                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
180                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
181                } else {
182                    (IconName::Cog, false)
183                };
184
185            tool_uses.push(ToolUse {
186                id: tool_use.id.clone(),
187                name: tool_use.name.clone().into(),
188                ui_text: self.tool_ui_label(
189                    &tool_use.name,
190                    &tool_use.input,
191                    tool_use.is_input_complete,
192                    cx,
193                ),
194                input: tool_use.input.clone(),
195                status,
196                icon,
197                needs_confirmation,
198            })
199        }
200
201        tool_uses
202    }
203
204    pub fn tool_ui_label(
205        &self,
206        tool_name: &str,
207        input: &serde_json::Value,
208        is_input_complete: bool,
209        cx: &App,
210    ) -> SharedString {
211        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
212            if is_input_complete {
213                tool.ui_text(input).into()
214            } else {
215                tool.still_streaming_ui_text(input).into()
216            }
217        } else {
218            format!("Unknown tool {tool_name:?}").into()
219        }
220    }
221
222    pub fn tool_results_for_message(
223        &self,
224        assistant_message_id: MessageId,
225    ) -> Vec<&LanguageModelToolResult> {
226        let Some(tool_uses) = self
227            .tool_uses_by_assistant_message
228            .get(&assistant_message_id)
229        else {
230            return Vec::new();
231        };
232
233        tool_uses
234            .iter()
235            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
236            .collect()
237    }
238
239    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
240        self.tool_uses_by_assistant_message
241            .get(&assistant_message_id)
242            .map_or(false, |results| !results.is_empty())
243    }
244
245    pub fn tool_result(
246        &self,
247        tool_use_id: &LanguageModelToolUseId,
248    ) -> Option<&LanguageModelToolResult> {
249        self.tool_results.get(tool_use_id)
250    }
251
252    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
253        self.tool_result_cards.get(tool_use_id)
254    }
255
256    pub fn insert_tool_result_card(
257        &mut self,
258        tool_use_id: LanguageModelToolUseId,
259        card: AnyToolCard,
260    ) {
261        self.tool_result_cards.insert(tool_use_id, card);
262    }
263
264    pub fn request_tool_use(
265        &mut self,
266        assistant_message_id: MessageId,
267        tool_use: LanguageModelToolUse,
268        metadata: ToolUseMetadata,
269        cx: &App,
270    ) -> Arc<str> {
271        let tool_uses = self
272            .tool_uses_by_assistant_message
273            .entry(assistant_message_id)
274            .or_default();
275
276        let mut existing_tool_use_found = false;
277
278        for existing_tool_use in tool_uses.iter_mut() {
279            if existing_tool_use.id == tool_use.id {
280                *existing_tool_use = tool_use.clone();
281                existing_tool_use_found = true;
282            }
283        }
284
285        if !existing_tool_use_found {
286            tool_uses.push(tool_use.clone());
287        }
288
289        let status = if tool_use.is_input_complete {
290            self.tool_use_metadata_by_id
291                .insert(tool_use.id.clone(), metadata);
292
293            PendingToolUseStatus::Idle
294        } else {
295            PendingToolUseStatus::InputStillStreaming
296        };
297
298        let ui_text: Arc<str> = self
299            .tool_ui_label(
300                &tool_use.name,
301                &tool_use.input,
302                tool_use.is_input_complete,
303                cx,
304            )
305            .into();
306
307        self.pending_tool_uses_by_id.insert(
308            tool_use.id.clone(),
309            PendingToolUse {
310                assistant_message_id,
311                id: tool_use.id,
312                name: tool_use.name.clone(),
313                ui_text: ui_text.clone(),
314                input: tool_use.input,
315                status,
316            },
317        );
318
319        ui_text
320    }
321
322    pub fn run_pending_tool(
323        &mut self,
324        tool_use_id: LanguageModelToolUseId,
325        ui_text: SharedString,
326        task: Task<()>,
327    ) {
328        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
329            tool_use.ui_text = ui_text.into();
330            tool_use.status = PendingToolUseStatus::Running {
331                _task: task.shared(),
332            };
333        }
334    }
335
336    pub fn confirm_tool_use(
337        &mut self,
338        tool_use_id: LanguageModelToolUseId,
339        ui_text: impl Into<Arc<str>>,
340        input: serde_json::Value,
341        messages: Arc<Vec<LanguageModelRequestMessage>>,
342        tool: Arc<dyn Tool>,
343    ) {
344        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
345            let ui_text = ui_text.into();
346            tool_use.ui_text = ui_text.clone();
347            let confirmation = Confirmation {
348                tool_use_id,
349                input,
350                messages,
351                tool,
352                ui_text,
353            };
354            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
355        }
356    }
357
358    pub fn insert_tool_output(
359        &mut self,
360        tool_use_id: LanguageModelToolUseId,
361        tool_name: Arc<str>,
362        output: Result<String>,
363        configured_model: Option<&ConfiguredModel>,
364    ) -> Option<PendingToolUse> {
365        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
366
367        telemetry::event!(
368            "Agent Tool Finished",
369            model = metadata
370                .as_ref()
371                .map(|metadata| metadata.model.telemetry_id()),
372            model_provider = metadata
373                .as_ref()
374                .map(|metadata| metadata.model.provider_id().to_string()),
375            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
376            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
377            tool_name,
378            success = output.is_ok()
379        );
380
381        match output {
382            Ok(tool_result) => {
383                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
384
385                // Protect from clearly large output
386                let tool_output_limit = configured_model
387                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
388                    .unwrap_or(usize::MAX);
389
390                let tool_result = if tool_result.len() <= tool_output_limit {
391                    tool_result
392                } else {
393                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
394
395                    format!(
396                        "Tool result too long. The first {} bytes:\n\n{}",
397                        truncated.len(),
398                        truncated
399                    )
400                };
401
402                self.tool_results.insert(
403                    tool_use_id.clone(),
404                    LanguageModelToolResult {
405                        tool_use_id: tool_use_id.clone(),
406                        tool_name,
407                        content: tool_result.into(),
408                        is_error: false,
409                    },
410                );
411                self.pending_tool_uses_by_id.remove(&tool_use_id)
412            }
413            Err(err) => {
414                self.tool_results.insert(
415                    tool_use_id.clone(),
416                    LanguageModelToolResult {
417                        tool_use_id: tool_use_id.clone(),
418                        tool_name,
419                        content: err.to_string().into(),
420                        is_error: true,
421                    },
422                );
423
424                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
425                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
426                }
427
428                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
429            }
430        }
431    }
432
433    pub fn attach_tool_uses(
434        &self,
435        message_id: MessageId,
436        request_message: &mut LanguageModelRequestMessage,
437    ) {
438        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
439            for tool_use in tool_uses {
440                if self.tool_results.contains_key(&tool_use.id) {
441                    // Do not send tool uses until they are completed
442                    request_message
443                        .content
444                        .push(MessageContent::ToolUse(tool_use.clone()));
445                } else {
446                    log::debug!(
447                        "skipped tool use {:?} because it is still pending",
448                        tool_use
449                    );
450                }
451            }
452        }
453    }
454
455    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
456        self.tool_uses_by_assistant_message
457            .contains_key(&assistant_message_id)
458    }
459
460    pub fn tool_results_message(
461        &self,
462        assistant_message_id: MessageId,
463    ) -> Option<LanguageModelRequestMessage> {
464        let tool_uses = self
465            .tool_uses_by_assistant_message
466            .get(&assistant_message_id)?;
467
468        if tool_uses.is_empty() {
469            return None;
470        }
471
472        let mut request_message = LanguageModelRequestMessage {
473            role: Role::User,
474            content: vec![],
475            cache: false,
476        };
477
478        for tool_use in tool_uses {
479            if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
480                request_message
481                    .content
482                    .push(MessageContent::ToolResult(LanguageModelToolResult {
483                        tool_use_id: tool_use.id.clone(),
484                        tool_name: tool_result.tool_name.clone(),
485                        is_error: tool_result.is_error,
486                        content: if tool_result.content.is_empty() {
487                            // Surprisingly, the API fails if we return an empty string here.
488                            // It thinks we are sending a tool use without a tool result.
489                            "<Tool returned an empty string>".into()
490                        } else {
491                            tool_result.content.clone()
492                        },
493                    }));
494            }
495        }
496
497        Some(request_message)
498    }
499}
500
501#[derive(Debug, Clone)]
502pub struct PendingToolUse {
503    pub id: LanguageModelToolUseId,
504    /// The ID of the Assistant message in which the tool use was requested.
505    #[allow(unused)]
506    pub assistant_message_id: MessageId,
507    pub name: Arc<str>,
508    pub ui_text: Arc<str>,
509    pub input: serde_json::Value,
510    pub status: PendingToolUseStatus,
511}
512
513#[derive(Debug, Clone)]
514pub struct Confirmation {
515    pub tool_use_id: LanguageModelToolUseId,
516    pub input: serde_json::Value,
517    pub ui_text: Arc<str>,
518    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
519    pub tool: Arc<dyn Tool>,
520}
521
522#[derive(Debug, Clone)]
523pub enum PendingToolUseStatus {
524    InputStillStreaming,
525    Idle,
526    NeedsConfirmation(Arc<Confirmation>),
527    Running { _task: Shared<Task<()>> },
528    Error(#[allow(unused)] Arc<str>),
529}
530
531impl PendingToolUseStatus {
532    pub fn is_idle(&self) -> bool {
533        matches!(self, PendingToolUseStatus::Idle)
534    }
535
536    pub fn is_error(&self) -> bool {
537        matches!(self, PendingToolUseStatus::Error(_))
538    }
539
540    pub fn needs_confirmation(&self) -> bool {
541        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
542    }
543}
544
545#[derive(Clone)]
546pub struct ToolUseMetadata {
547    pub model: Arc<dyn LanguageModel>,
548    pub thread_id: ThreadId,
549    pub prompt_id: PromptId,
550}