tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::MessageId;
 17use crate::thread_store::SerializedMessage;
 18
/// A tool use requested in an assistant message, combined with the display
/// information that [`ToolUseState::tool_uses_for_message`] derives for it.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label for the tool use: the tool's own UI text, or an "Unknown tool"
    /// message when the tool is not found in the working set.
    pub ui_text: SharedString,
    /// Current status, derived from the stored result or the pending state.
    pub status: ToolUseStatus,
    /// The JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Icon reported by the tool (`IconName::Cog` when the tool is not found).
    pub icon: ui::IconName,
    /// Whether the tool reports that this input needs confirmation
    /// (`false` when the tool is not found).
    pub needs_confirmation: bool,
}
 29
/// Status of a tool use as reported by [`ToolUseState::tool_uses_for_message`].
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The pending tool use is waiting for confirmation.
    NeedsConfirmation,
    /// The tool use has no result yet and is either idle or not tracked as pending.
    Pending,
    /// The pending tool use is in the running state.
    Running,
    /// A non-error result was recorded; carries the result content.
    Finished(SharedString),
    /// An error was recorded; carries the error result content or the pending error message.
    Error(SharedString),
}
 38
 39impl ToolUseStatus {
 40    pub fn text(&self) -> SharedString {
 41        match self {
 42            ToolUseStatus::NeedsConfirmation => "".into(),
 43            ToolUseStatus::Pending => "".into(),
 44            ToolUseStatus::Running => "".into(),
 45            ToolUseStatus::Finished(out) => out.clone(),
 46            ToolUseStatus::Error(out) => out.clone(),
 47        }
 48    }
 49}
 50
/// Tracks tool uses requested by the assistant, their pending state, and
/// their results, keyed by the messages they belong to.
pub struct ToolUseState {
    /// Working set used to look up tools by name (label, icon, confirmation).
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Ids of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results, keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and are still tracked as pending, keyed by id.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
 58
/// Placeholder text that [`ToolUseState::attach_tool_uses`] adds to a request
/// message when its last content item is empty text or is not text at all.
pub const USING_TOOL_MARKER: &str = "<using_tool>";
 60
 61impl ToolUseState {
 62    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 63        Self {
 64            tools,
 65            tool_uses_by_assistant_message: HashMap::default(),
 66            tool_uses_by_user_message: HashMap::default(),
 67            tool_results: HashMap::default(),
 68            pending_tool_uses_by_id: HashMap::default(),
 69        }
 70    }
 71
 72    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 73    ///
 74    /// Accepts a function to filter the tools that should be used to populate the state.
 75    pub fn from_serialized_messages(
 76        tools: Entity<ToolWorkingSet>,
 77        messages: &[SerializedMessage],
 78        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 79    ) -> Self {
 80        let mut this = Self::new(tools);
 81        let mut tool_names_by_id = HashMap::default();
 82
 83        for message in messages {
 84            match message.role {
 85                Role::Assistant => {
 86                    if !message.tool_uses.is_empty() {
 87                        let tool_uses = message
 88                            .tool_uses
 89                            .iter()
 90                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 91                            .map(|tool_use| LanguageModelToolUse {
 92                                id: tool_use.id.clone(),
 93                                name: tool_use.name.clone().into(),
 94                                input: tool_use.input.clone(),
 95                            })
 96                            .collect::<Vec<_>>();
 97
 98                        tool_names_by_id.extend(
 99                            tool_uses
100                                .iter()
101                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
102                        );
103
104                        this.tool_uses_by_assistant_message
105                            .insert(message.id, tool_uses);
106                    }
107                }
108                Role::User => {
109                    if !message.tool_results.is_empty() {
110                        let tool_uses_by_user_message = this
111                            .tool_uses_by_user_message
112                            .entry(message.id)
113                            .or_default();
114
115                        for tool_result in &message.tool_results {
116                            let tool_use_id = tool_result.tool_use_id.clone();
117                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
118                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
119                                continue;
120                            };
121
122                            if !(filter_by_tool_name)(tool_use.as_ref()) {
123                                continue;
124                            }
125
126                            tool_uses_by_user_message.push(tool_use_id.clone());
127                            this.tool_results.insert(
128                                tool_use_id.clone(),
129                                LanguageModelToolResult {
130                                    tool_use_id,
131                                    tool_name: tool_use.clone(),
132                                    is_error: tool_result.is_error,
133                                    content: tool_result.content.clone(),
134                                },
135                            );
136                        }
137                    }
138                }
139                Role::System => {}
140            }
141        }
142
143        this
144    }
145
146    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
147        let mut pending_tools = Vec::new();
148        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
149            self.tool_results.insert(
150                tool_use_id.clone(),
151                LanguageModelToolResult {
152                    tool_use_id,
153                    tool_name: tool_use.name.clone(),
154                    content: "Tool canceled by user".into(),
155                    is_error: true,
156                },
157            );
158            pending_tools.push(tool_use.clone());
159        }
160        pending_tools
161    }
162
163    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
164        self.pending_tool_uses_by_id.values().collect()
165    }
166
167    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
168        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169            return Vec::new();
170        };
171
172        let mut tool_uses = Vec::new();
173
174        for tool_use in tool_uses_for_message.iter() {
175            let tool_result = self.tool_results.get(&tool_use.id);
176
177            let status = (|| {
178                if let Some(tool_result) = tool_result {
179                    return if tool_result.is_error {
180                        ToolUseStatus::Error(tool_result.content.clone().into())
181                    } else {
182                        ToolUseStatus::Finished(tool_result.content.clone().into())
183                    };
184                }
185
186                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
187                    match pending_tool_use.status {
188                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
189                        PendingToolUseStatus::NeedsConfirmation { .. } => {
190                            ToolUseStatus::NeedsConfirmation
191                        }
192                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
193                        PendingToolUseStatus::Error(ref err) => {
194                            ToolUseStatus::Error(err.clone().into())
195                        }
196                    }
197                } else {
198                    ToolUseStatus::Pending
199                }
200            })();
201
202            let (icon, needs_confirmation) =
203                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
204                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
205                } else {
206                    (IconName::Cog, false)
207                };
208
209            tool_uses.push(ToolUse {
210                id: tool_use.id.clone(),
211                name: tool_use.name.clone().into(),
212                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
213                input: tool_use.input.clone(),
214                status,
215                icon,
216                needs_confirmation,
217            })
218        }
219
220        tool_uses
221    }
222
223    pub fn tool_ui_label(
224        &self,
225        tool_name: &str,
226        input: &serde_json::Value,
227        cx: &App,
228    ) -> SharedString {
229        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
230            tool.ui_text(input).into()
231        } else {
232            format!("Unknown tool {tool_name:?}").into()
233        }
234    }
235
236    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
237        let empty = Vec::new();
238
239        self.tool_uses_by_user_message
240            .get(&message_id)
241            .unwrap_or(&empty)
242            .iter()
243            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
244            .collect()
245    }
246
247    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
248        self.tool_uses_by_user_message
249            .get(&message_id)
250            .map_or(false, |results| !results.is_empty())
251    }
252
253    pub fn tool_result(
254        &self,
255        tool_use_id: &LanguageModelToolUseId,
256    ) -> Option<&LanguageModelToolResult> {
257        self.tool_results.get(tool_use_id)
258    }
259
260    pub fn request_tool_use(
261        &mut self,
262        assistant_message_id: MessageId,
263        tool_use: LanguageModelToolUse,
264        cx: &App,
265    ) {
266        self.tool_uses_by_assistant_message
267            .entry(assistant_message_id)
268            .or_default()
269            .push(tool_use.clone());
270
271        // The tool use is being requested by the Assistant, so we want to
272        // attach the tool results to the next user message.
273        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
274        self.tool_uses_by_user_message
275            .entry(next_user_message_id)
276            .or_default()
277            .push(tool_use.id.clone());
278
279        self.pending_tool_uses_by_id.insert(
280            tool_use.id.clone(),
281            PendingToolUse {
282                assistant_message_id,
283                id: tool_use.id,
284                name: tool_use.name.clone(),
285                ui_text: self
286                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
287                    .into(),
288                input: tool_use.input,
289                status: PendingToolUseStatus::Idle,
290            },
291        );
292    }
293
294    pub fn run_pending_tool(
295        &mut self,
296        tool_use_id: LanguageModelToolUseId,
297        ui_text: SharedString,
298        task: Task<()>,
299    ) {
300        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
301            tool_use.ui_text = ui_text.into();
302            tool_use.status = PendingToolUseStatus::Running {
303                _task: task.shared(),
304            };
305        }
306    }
307
308    pub fn confirm_tool_use(
309        &mut self,
310        tool_use_id: LanguageModelToolUseId,
311        ui_text: impl Into<Arc<str>>,
312        input: serde_json::Value,
313        messages: Arc<Vec<LanguageModelRequestMessage>>,
314        tool: Arc<dyn Tool>,
315    ) {
316        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
317            let ui_text = ui_text.into();
318            tool_use.ui_text = ui_text.clone();
319            let confirmation = Confirmation {
320                tool_use_id,
321                input,
322                messages,
323                tool,
324                ui_text,
325            };
326            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
327        }
328    }
329
330    pub fn insert_tool_output(
331        &mut self,
332        tool_use_id: LanguageModelToolUseId,
333        tool_name: Arc<str>,
334        output: Result<String>,
335        cx: &App,
336    ) -> Option<PendingToolUse> {
337        telemetry::event!("Agent Tool Finished", tool_name, success = output.is_ok());
338
339        match output {
340            Ok(tool_result) => {
341                let model_registry = LanguageModelRegistry::read_global(cx);
342
343                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
344
345                // Protect from clearly large output
346                let tool_output_limit = model_registry
347                    .default_model()
348                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
349                    .unwrap_or(usize::MAX);
350
351                let tool_result = if tool_result.len() <= tool_output_limit {
352                    tool_result
353                } else {
354                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
355
356                    format!(
357                        "Tool result too long. The first {} bytes:\n\n{}",
358                        truncated.len(),
359                        truncated
360                    )
361                };
362
363                self.tool_results.insert(
364                    tool_use_id.clone(),
365                    LanguageModelToolResult {
366                        tool_use_id: tool_use_id.clone(),
367                        tool_name,
368                        content: tool_result.into(),
369                        is_error: false,
370                    },
371                );
372                self.pending_tool_uses_by_id.remove(&tool_use_id)
373            }
374            Err(err) => {
375                self.tool_results.insert(
376                    tool_use_id.clone(),
377                    LanguageModelToolResult {
378                        tool_use_id: tool_use_id.clone(),
379                        tool_name,
380                        content: err.to_string().into(),
381                        is_error: true,
382                    },
383                );
384
385                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
386                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
387                }
388
389                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
390            }
391        }
392    }
393
394    pub fn attach_tool_uses(
395        &self,
396        message_id: MessageId,
397        request_message: &mut LanguageModelRequestMessage,
398    ) {
399        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
400            let mut found_tool_use = false;
401
402            for tool_use in tool_uses {
403                if self.tool_results.contains_key(&tool_use.id) {
404                    if !found_tool_use {
405                        // The API fails if a message contains a tool use without any (non-whitespace) text around it
406                        match request_message.content.last_mut() {
407                            Some(MessageContent::Text(txt)) => {
408                                if txt.is_empty() {
409                                    txt.push_str(USING_TOOL_MARKER);
410                                }
411                            }
412                            None | Some(_) => {
413                                request_message
414                                    .content
415                                    .push(MessageContent::Text(USING_TOOL_MARKER.into()));
416                            }
417                        };
418                    }
419
420                    found_tool_use = true;
421
422                    // Do not send tool uses until they are completed
423                    request_message
424                        .content
425                        .push(MessageContent::ToolUse(tool_use.clone()));
426                } else {
427                    log::debug!(
428                        "skipped tool use {:?} because it is still pending",
429                        tool_use
430                    );
431                }
432            }
433        }
434    }
435
436    pub fn attach_tool_results(
437        &self,
438        message_id: MessageId,
439        request_message: &mut LanguageModelRequestMessage,
440    ) {
441        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
442            for tool_use_id in tool_uses {
443                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
444                    request_message.content.push(MessageContent::ToolResult(
445                        LanguageModelToolResult {
446                            tool_use_id: tool_use_id.clone(),
447                            tool_name: tool_result.tool_name.clone(),
448                            is_error: tool_result.is_error,
449                            content: if tool_result.content.is_empty() {
450                                // Surprisingly, the API fails if we return an empty string here.
451                                // It thinks we are sending a tool use without a tool result.
452                                "<Tool returned an empty string>".into()
453                            } else {
454                                tool_result.content.clone()
455                            },
456                        },
457                    ));
458                }
459            }
460        }
461    }
462}
463
/// A tool use that has been requested and is tracked by [`ToolUseState`]
/// until its output is recorded or it is canceled.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; replaced when the tool use is confirmed or run.
    pub ui_text: Arc<str>,
    /// The JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Current state of the pending tool use.
    pub status: PendingToolUseStatus,
}
475
/// Data stored by [`ToolUseState::confirm_tool_use`] for a tool use that is
/// waiting for confirmation.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// The JSON input supplied with the confirmation request.
    pub input: serde_json::Value,
    /// Label for the tool use.
    pub ui_text: Arc<str>,
    /// Request messages supplied with the confirmation request.
    /// NOTE(review): presumably passed to the tool once confirmed — verify against the caller.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool the confirmation is for.
    pub tool: Arc<dyn Tool>,
}
484
/// State of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Requested but not yet confirmed, run, or failed.
    Idle,
    /// Waiting for confirmation; carries the data stored with the request.
    NeedsConfirmation(Arc<Confirmation>),
    /// Running; holds the shared task passed to `run_pending_tool`.
    Running { _task: Shared<Task<()>> },
    /// The tool run failed; carries the error message.
    Error(#[allow(unused)] Arc<str>),
}
492
493impl PendingToolUseStatus {
494    pub fn is_idle(&self) -> bool {
495        matches!(self, PendingToolUseStatus::Idle)
496    }
497
498    pub fn is_error(&self) -> bool {
499        matches!(self, PendingToolUseStatus::Error(_))
500    }
501
502    pub fn needs_confirmation(&self) -> bool {
503        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
504    }
505}