tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
/// A display-oriented snapshot of a single tool use, as assembled by
/// [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use this snapshot describes.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label computed by [`ToolUseState::tool_ui_label`] for this tool use.
    pub ui_text: SharedString,
    /// Status derived from the stored result or the pending entry, if either exists.
    pub status: ToolUseStatus,
    /// Input the tool was requested with.
    pub input: serde_json::Value,
    /// The tool's icon, or `IconName::Cog` when the tool is not found in the working set.
    pub icon: ui::IconName,
    /// Whether the tool reports that it needs confirmation for this input; `false` when the
    /// tool is not found in the working set.
    pub needs_confirmation: bool,
}
 29
/// Tracks the tool uses requested by assistant messages, together with their pending state,
/// results, result cards, and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses recorded for each assistant message, keyed by that message's ID.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Results recorded so far, keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have not been removed yet. Entries are removed when a
    /// successful output is inserted or when pending tool uses are canceled.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered for tool uses, keyed by tool use ID.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete; taken back out when the tool's
    /// output is inserted so it can be reported to telemetry.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 38
 39impl ToolUseState {
 40    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 41        Self {
 42            tools,
 43            tool_uses_by_assistant_message: HashMap::default(),
 44            tool_results: HashMap::default(),
 45            pending_tool_uses_by_id: HashMap::default(),
 46            tool_result_cards: HashMap::default(),
 47            tool_use_metadata_by_id: HashMap::default(),
 48        }
 49    }
 50
 51    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 52    ///
 53    /// Accepts a function to filter the tools that should be used to populate the state.
 54    pub fn from_serialized_messages(
 55        tools: Entity<ToolWorkingSet>,
 56        messages: &[SerializedMessage],
 57    ) -> Self {
 58        let mut this = Self::new(tools);
 59        let mut tool_names_by_id = HashMap::default();
 60
 61        for message in messages {
 62            match message.role {
 63                Role::Assistant => {
 64                    if !message.tool_uses.is_empty() {
 65                        let tool_uses = message
 66                            .tool_uses
 67                            .iter()
 68                            .map(|tool_use| LanguageModelToolUse {
 69                                id: tool_use.id.clone(),
 70                                name: tool_use.name.clone().into(),
 71                                input: tool_use.input.clone(),
 72                                is_input_complete: true,
 73                            })
 74                            .collect::<Vec<_>>();
 75
 76                        tool_names_by_id.extend(
 77                            tool_uses
 78                                .iter()
 79                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 80                        );
 81
 82                        this.tool_uses_by_assistant_message
 83                            .insert(message.id, tool_uses);
 84
 85                        for tool_result in &message.tool_results {
 86                            let tool_use_id = tool_result.tool_use_id.clone();
 87                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 88                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 89                                continue;
 90                            };
 91
 92                            this.tool_results.insert(
 93                                tool_use_id.clone(),
 94                                LanguageModelToolResult {
 95                                    tool_use_id,
 96                                    tool_name: tool_use.clone(),
 97                                    is_error: tool_result.is_error,
 98                                    content: tool_result.content.clone(),
 99                                },
100                            );
101                        }
102                    }
103                }
104                Role::System | Role::User => {}
105            }
106        }
107
108        this
109    }
110
111    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
112        let mut pending_tools = Vec::new();
113        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
114            self.tool_results.insert(
115                tool_use_id.clone(),
116                LanguageModelToolResult {
117                    tool_use_id,
118                    tool_name: tool_use.name.clone(),
119                    content: "Tool canceled by user".into(),
120                    is_error: true,
121                },
122            );
123            pending_tools.push(tool_use.clone());
124        }
125        pending_tools
126    }
127
    /// Returns references to all tool uses currently in the pending map, in the map's
    /// iteration order.
    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
        self.pending_tool_uses_by_id.values().collect()
    }
131
132    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
133        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
134            return Vec::new();
135        };
136
137        let mut tool_uses = Vec::new();
138
139        for tool_use in tool_uses_for_message.iter() {
140            let tool_result = self.tool_results.get(&tool_use.id);
141
142            let status = (|| {
143                if let Some(tool_result) = tool_result {
144                    return if tool_result.is_error {
145                        ToolUseStatus::Error(tool_result.content.clone().into())
146                    } else {
147                        ToolUseStatus::Finished(tool_result.content.clone().into())
148                    };
149                }
150
151                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
152                    match pending_tool_use.status {
153                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
154                        PendingToolUseStatus::NeedsConfirmation { .. } => {
155                            ToolUseStatus::NeedsConfirmation
156                        }
157                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
158                        PendingToolUseStatus::Error(ref err) => {
159                            ToolUseStatus::Error(err.clone().into())
160                        }
161                        PendingToolUseStatus::InputStillStreaming => {
162                            ToolUseStatus::InputStillStreaming
163                        }
164                    }
165                } else {
166                    ToolUseStatus::Pending
167                }
168            })();
169
170            let (icon, needs_confirmation) =
171                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
172                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
173                } else {
174                    (IconName::Cog, false)
175                };
176
177            tool_uses.push(ToolUse {
178                id: tool_use.id.clone(),
179                name: tool_use.name.clone().into(),
180                ui_text: self.tool_ui_label(
181                    &tool_use.name,
182                    &tool_use.input,
183                    tool_use.is_input_complete,
184                    cx,
185                ),
186                input: tool_use.input.clone(),
187                status,
188                icon,
189                needs_confirmation,
190            })
191        }
192
193        tool_uses
194    }
195
196    pub fn tool_ui_label(
197        &self,
198        tool_name: &str,
199        input: &serde_json::Value,
200        is_input_complete: bool,
201        cx: &App,
202    ) -> SharedString {
203        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
204            if is_input_complete {
205                tool.ui_text(input).into()
206            } else {
207                tool.still_streaming_ui_text(input).into()
208            }
209        } else {
210            format!("Unknown tool {tool_name:?}").into()
211        }
212    }
213
214    pub fn tool_results_for_message(
215        &self,
216        assistant_message_id: MessageId,
217    ) -> Vec<&LanguageModelToolResult> {
218        let Some(tool_uses) = self
219            .tool_uses_by_assistant_message
220            .get(&assistant_message_id)
221        else {
222            return Vec::new();
223        };
224
225        tool_uses
226            .iter()
227            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
228            .collect()
229    }
230
231    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
232        self.tool_uses_by_assistant_message
233            .get(&assistant_message_id)
234            .map_or(false, |results| !results.is_empty())
235    }
236
    /// Returns the stored result for `tool_use_id`, or `None` when no result has been recorded
    /// for it.
    pub fn tool_result(
        &self,
        tool_use_id: &LanguageModelToolUseId,
    ) -> Option<&LanguageModelToolResult> {
        self.tool_results.get(tool_use_id)
    }
243
    /// Returns the card registered for `tool_use_id`, or `None` when there is none.
    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
        self.tool_result_cards.get(tool_use_id)
    }
247
    /// Registers `card` for `tool_use_id`, replacing any card previously stored under that ID.
    pub fn insert_tool_result_card(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        card: AnyToolCard,
    ) {
        self.tool_result_cards.insert(tool_use_id, card);
    }
255
256    pub fn request_tool_use(
257        &mut self,
258        assistant_message_id: MessageId,
259        tool_use: LanguageModelToolUse,
260        metadata: ToolUseMetadata,
261        cx: &App,
262    ) -> Arc<str> {
263        let tool_uses = self
264            .tool_uses_by_assistant_message
265            .entry(assistant_message_id)
266            .or_default();
267
268        let mut existing_tool_use_found = false;
269
270        for existing_tool_use in tool_uses.iter_mut() {
271            if existing_tool_use.id == tool_use.id {
272                *existing_tool_use = tool_use.clone();
273                existing_tool_use_found = true;
274            }
275        }
276
277        if !existing_tool_use_found {
278            tool_uses.push(tool_use.clone());
279        }
280
281        let status = if tool_use.is_input_complete {
282            self.tool_use_metadata_by_id
283                .insert(tool_use.id.clone(), metadata);
284
285            PendingToolUseStatus::Idle
286        } else {
287            PendingToolUseStatus::InputStillStreaming
288        };
289
290        let ui_text: Arc<str> = self
291            .tool_ui_label(
292                &tool_use.name,
293                &tool_use.input,
294                tool_use.is_input_complete,
295                cx,
296            )
297            .into();
298
299        self.pending_tool_uses_by_id.insert(
300            tool_use.id.clone(),
301            PendingToolUse {
302                assistant_message_id,
303                id: tool_use.id,
304                name: tool_use.name.clone(),
305                ui_text: ui_text.clone(),
306                input: tool_use.input,
307                status,
308            },
309        );
310
311        ui_text
312    }
313
314    pub fn run_pending_tool(
315        &mut self,
316        tool_use_id: LanguageModelToolUseId,
317        ui_text: SharedString,
318        task: Task<()>,
319    ) {
320        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
321            tool_use.ui_text = ui_text.into();
322            tool_use.status = PendingToolUseStatus::Running {
323                _task: task.shared(),
324            };
325        }
326    }
327
328    pub fn confirm_tool_use(
329        &mut self,
330        tool_use_id: LanguageModelToolUseId,
331        ui_text: impl Into<Arc<str>>,
332        input: serde_json::Value,
333        messages: Arc<Vec<LanguageModelRequestMessage>>,
334        tool: Arc<dyn Tool>,
335    ) {
336        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
337            let ui_text = ui_text.into();
338            tool_use.ui_text = ui_text.clone();
339            let confirmation = Confirmation {
340                tool_use_id,
341                input,
342                messages,
343                tool,
344                ui_text,
345            };
346            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
347        }
348    }
349
350    pub fn insert_tool_output(
351        &mut self,
352        tool_use_id: LanguageModelToolUseId,
353        tool_name: Arc<str>,
354        output: Result<String>,
355        cx: &App,
356    ) -> Option<PendingToolUse> {
357        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
358
359        telemetry::event!(
360            "Agent Tool Finished",
361            model = metadata
362                .as_ref()
363                .map(|metadata| metadata.model.telemetry_id()),
364            model_provider = metadata
365                .as_ref()
366                .map(|metadata| metadata.model.provider_id().to_string()),
367            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
368            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
369            tool_name,
370            success = output.is_ok()
371        );
372
373        match output {
374            Ok(tool_result) => {
375                let model_registry = LanguageModelRegistry::read_global(cx);
376
377                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
378
379                // Protect from clearly large output
380                let tool_output_limit = model_registry
381                    .default_model()
382                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
383                    .unwrap_or(usize::MAX);
384
385                let tool_result = if tool_result.len() <= tool_output_limit {
386                    tool_result
387                } else {
388                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
389
390                    format!(
391                        "Tool result too long. The first {} bytes:\n\n{}",
392                        truncated.len(),
393                        truncated
394                    )
395                };
396
397                self.tool_results.insert(
398                    tool_use_id.clone(),
399                    LanguageModelToolResult {
400                        tool_use_id: tool_use_id.clone(),
401                        tool_name,
402                        content: tool_result.into(),
403                        is_error: false,
404                    },
405                );
406                self.pending_tool_uses_by_id.remove(&tool_use_id)
407            }
408            Err(err) => {
409                self.tool_results.insert(
410                    tool_use_id.clone(),
411                    LanguageModelToolResult {
412                        tool_use_id: tool_use_id.clone(),
413                        tool_name,
414                        content: err.to_string().into(),
415                        is_error: true,
416                    },
417                );
418
419                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
420                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
421                }
422
423                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
424            }
425        }
426    }
427
428    pub fn attach_tool_uses(
429        &self,
430        message_id: MessageId,
431        request_message: &mut LanguageModelRequestMessage,
432    ) {
433        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
434            for tool_use in tool_uses {
435                if self.tool_results.contains_key(&tool_use.id) {
436                    // Do not send tool uses until they are completed
437                    request_message
438                        .content
439                        .push(MessageContent::ToolUse(tool_use.clone()));
440                } else {
441                    log::debug!(
442                        "skipped tool use {:?} because it is still pending",
443                        tool_use
444                    );
445                }
446            }
447        }
448    }
449
    /// Returns `true` when an entry exists for `assistant_message_id` in the map of recorded
    /// tool uses, even if that entry is an empty list.
    ///
    /// NOTE(review): despite the name, this does not consult `tool_results` — confirm that
    /// callers expect this.
    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
        self.tool_uses_by_assistant_message
            .contains_key(&assistant_message_id)
    }
454
455    pub fn tool_results_message(
456        &self,
457        assistant_message_id: MessageId,
458    ) -> Option<LanguageModelRequestMessage> {
459        let tool_uses = self
460            .tool_uses_by_assistant_message
461            .get(&assistant_message_id)?;
462
463        if tool_uses.is_empty() {
464            return None;
465        }
466
467        let mut request_message = LanguageModelRequestMessage {
468            role: Role::User,
469            content: vec![],
470            cache: false,
471        };
472
473        for tool_use in tool_uses {
474            if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
475                request_message
476                    .content
477                    .push(MessageContent::ToolResult(LanguageModelToolResult {
478                        tool_use_id: tool_use.id.clone(),
479                        tool_name: tool_result.tool_name.clone(),
480                        is_error: tool_result.is_error,
481                        content: if tool_result.content.is_empty() {
482                            // Surprisingly, the API fails if we return an empty string here.
483                            // It thinks we are sending a tool use without a tool result.
484                            "<Tool returned an empty string>".into()
485                        } else {
486                            tool_result.content.clone()
487                        },
488                    }));
489            }
490        }
491
492        Some(request_message)
493    }
494}
495
/// A tool use that has been requested and is still tracked in the pending map.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; updated when the tool use is run or awaits confirmation.
    pub ui_text: Arc<str>,
    /// Input the tool was requested with.
    pub input: serde_json::Value,
    /// Current stage of the tool use.
    pub status: PendingToolUseStatus,
}
507
/// The data captured when a pending tool use is put into the
/// [`PendingToolUseStatus::NeedsConfirmation`] state.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Input supplied for the tool use.
    pub input: serde_json::Value,
    /// Label of the tool use at the time confirmation was requested.
    pub ui_text: Arc<str>,
    /// Request messages supplied alongside the confirmation request.
    /// NOTE(review): presumably the conversation handed to the tool once confirmed — verify
    /// against the caller.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
516
/// The stage a [`PendingToolUse`] is in.
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The tool use was requested while its input was not yet complete.
    InputStillStreaming,
    /// The tool use was requested with complete input and nothing further has been recorded.
    Idle,
    /// The tool use is waiting for confirmation; carries the captured [`Confirmation`].
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool use is running; holds a shared handle to its task.
    Running { _task: Shared<Task<()>> },
    /// The tool produced an error; carries the error's message.
    Error(#[allow(unused)] Arc<str>),
}
525
526impl PendingToolUseStatus {
527    pub fn is_idle(&self) -> bool {
528        matches!(self, PendingToolUseStatus::Idle)
529    }
530
531    pub fn is_error(&self) -> bool {
532        matches!(self, PendingToolUseStatus::Error(_))
533    }
534
535    pub fn needs_confirmation(&self) -> bool {
536        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
537    }
538}
539
/// Context stored for a tool use once its input is complete, and reported in the
/// "Agent Tool Finished" telemetry event when the tool's output is inserted.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry ID and provider ID are reported.
    pub model: Arc<dyn LanguageModel>,
    /// ID of the thread, reported as `thread_id`.
    pub thread_id: ThreadId,
    /// ID of the prompt, reported as `prompt_id`.
    pub prompt_id: PromptId,
}