tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, Role,
 12};
 13use project::Project;
 14use ui::{IconName, Window};
 15use util::truncate_lines_to_byte_limit;
 16
 17use crate::thread::{MessageId, PromptId, ThreadId};
 18use crate::thread_store::SerializedMessage;
 19
 20#[derive(Debug)]
 21pub struct ToolUse {
 22    pub id: LanguageModelToolUseId,
 23    pub name: SharedString,
 24    pub ui_text: SharedString,
 25    pub status: ToolUseStatus,
 26    pub input: serde_json::Value,
 27    pub icon: ui::IconName,
 28    pub needs_confirmation: bool,
 29}
 30
/// Tracks the tool uses requested by assistant messages, together with their
/// pending state, results, result cards, and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in the order they were recorded.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Results recorded for finished, failed, cancelled, or deserialized tool uses.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses awaiting completion. Entries are removed on success or cancellation,
    /// but kept (in the `Error` state) when the tool fails.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// UI cards for tool results, inserted directly or rebuilt from serialized output.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Telemetry metadata stored once a tool use's input is complete; removed when
    /// its output is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 39
 40impl ToolUseState {
 41    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 42        Self {
 43            tools,
 44            tool_uses_by_assistant_message: HashMap::default(),
 45            tool_results: HashMap::default(),
 46            pending_tool_uses_by_id: HashMap::default(),
 47            tool_result_cards: HashMap::default(),
 48            tool_use_metadata_by_id: HashMap::default(),
 49        }
 50    }
 51
 52    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 53    ///
 54    /// Accepts a function to filter the tools that should be used to populate the state.
 55    pub fn from_serialized_messages(
 56        tools: Entity<ToolWorkingSet>,
 57        messages: &[SerializedMessage],
 58        project: Entity<Project>,
 59        window: &mut Window,
 60        cx: &mut App,
 61    ) -> Self {
 62        let mut this = Self::new(tools);
 63        let mut tool_names_by_id = HashMap::default();
 64
 65        for message in messages {
 66            match message.role {
 67                Role::Assistant => {
 68                    if !message.tool_uses.is_empty() {
 69                        let tool_uses = message
 70                            .tool_uses
 71                            .iter()
 72                            .map(|tool_use| LanguageModelToolUse {
 73                                id: tool_use.id.clone(),
 74                                name: tool_use.name.clone().into(),
 75                                raw_input: tool_use.input.to_string(),
 76                                input: tool_use.input.clone(),
 77                                is_input_complete: true,
 78                            })
 79                            .collect::<Vec<_>>();
 80
 81                        tool_names_by_id.extend(
 82                            tool_uses
 83                                .iter()
 84                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 85                        );
 86
 87                        this.tool_uses_by_assistant_message
 88                            .insert(message.id, tool_uses);
 89
 90                        for tool_result in &message.tool_results {
 91                            let tool_use_id = tool_result.tool_use_id.clone();
 92                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 93                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 94                                continue;
 95                            };
 96
 97                            this.tool_results.insert(
 98                                tool_use_id.clone(),
 99                                LanguageModelToolResult {
100                                    tool_use_id: tool_use_id.clone(),
101                                    tool_name: tool_use.clone(),
102                                    is_error: tool_result.is_error,
103                                    content: tool_result.content.clone(),
104                                    output: tool_result.output.clone(),
105                                },
106                            );
107
108                            if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
109                                if let Some(output) = tool_result.output.clone() {
110                                    if let Some(card) =
111                                        tool.deserialize_card(output, project.clone(), window, cx)
112                                    {
113                                        this.tool_result_cards.insert(tool_use_id, card);
114                                    }
115                                }
116                            }
117                        }
118                    }
119                }
120                Role::System | Role::User => {}
121            }
122        }
123
124        this
125    }
126
127    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
128        let mut cancelled_tool_uses = Vec::new();
129        self.pending_tool_uses_by_id
130            .retain(|tool_use_id, tool_use| {
131                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
132                    return true;
133                }
134
135                let content = "Tool canceled by user".into();
136                self.tool_results.insert(
137                    tool_use_id.clone(),
138                    LanguageModelToolResult {
139                        tool_use_id: tool_use_id.clone(),
140                        tool_name: tool_use.name.clone(),
141                        content,
142                        output: None,
143                        is_error: true,
144                    },
145                );
146                cancelled_tool_uses.push(tool_use.clone());
147                false
148            });
149        cancelled_tool_uses
150    }
151
152    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
153        self.pending_tool_uses_by_id.values().collect()
154    }
155
156    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
157        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
158            return Vec::new();
159        };
160
161        let mut tool_uses = Vec::new();
162
163        for tool_use in tool_uses_for_message.iter() {
164            let tool_result = self.tool_results.get(&tool_use.id);
165
166            let status = (|| {
167                if let Some(tool_result) = tool_result {
168                    return if tool_result.is_error {
169                        ToolUseStatus::Error(tool_result.content.clone().into())
170                    } else {
171                        ToolUseStatus::Finished(tool_result.content.clone().into())
172                    };
173                }
174
175                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
176                    match pending_tool_use.status {
177                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
178                        PendingToolUseStatus::NeedsConfirmation { .. } => {
179                            ToolUseStatus::NeedsConfirmation
180                        }
181                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
182                        PendingToolUseStatus::Error(ref err) => {
183                            ToolUseStatus::Error(err.clone().into())
184                        }
185                        PendingToolUseStatus::InputStillStreaming => {
186                            ToolUseStatus::InputStillStreaming
187                        }
188                    }
189                } else {
190                    ToolUseStatus::Pending
191                }
192            })();
193
194            let (icon, needs_confirmation) =
195                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
196                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
197                } else {
198                    (IconName::Cog, false)
199                };
200
201            tool_uses.push(ToolUse {
202                id: tool_use.id.clone(),
203                name: tool_use.name.clone().into(),
204                ui_text: self.tool_ui_label(
205                    &tool_use.name,
206                    &tool_use.input,
207                    tool_use.is_input_complete,
208                    cx,
209                ),
210                input: tool_use.input.clone(),
211                status,
212                icon,
213                needs_confirmation,
214            })
215        }
216
217        tool_uses
218    }
219
220    pub fn tool_ui_label(
221        &self,
222        tool_name: &str,
223        input: &serde_json::Value,
224        is_input_complete: bool,
225        cx: &App,
226    ) -> SharedString {
227        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
228            if is_input_complete {
229                tool.ui_text(input).into()
230            } else {
231                tool.still_streaming_ui_text(input).into()
232            }
233        } else {
234            format!("Unknown tool {tool_name:?}").into()
235        }
236    }
237
238    pub fn tool_results_for_message(
239        &self,
240        assistant_message_id: MessageId,
241    ) -> Vec<&LanguageModelToolResult> {
242        let Some(tool_uses) = self
243            .tool_uses_by_assistant_message
244            .get(&assistant_message_id)
245        else {
246            return Vec::new();
247        };
248
249        tool_uses
250            .iter()
251            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
252            .collect()
253    }
254
255    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
256        self.tool_uses_by_assistant_message
257            .get(&assistant_message_id)
258            .map_or(false, |results| !results.is_empty())
259    }
260
261    pub fn tool_result(
262        &self,
263        tool_use_id: &LanguageModelToolUseId,
264    ) -> Option<&LanguageModelToolResult> {
265        self.tool_results.get(tool_use_id)
266    }
267
268    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
269        self.tool_result_cards.get(tool_use_id)
270    }
271
272    pub fn insert_tool_result_card(
273        &mut self,
274        tool_use_id: LanguageModelToolUseId,
275        card: AnyToolCard,
276    ) {
277        self.tool_result_cards.insert(tool_use_id, card);
278    }
279
280    pub fn request_tool_use(
281        &mut self,
282        assistant_message_id: MessageId,
283        tool_use: LanguageModelToolUse,
284        metadata: ToolUseMetadata,
285        cx: &App,
286    ) -> Arc<str> {
287        let tool_uses = self
288            .tool_uses_by_assistant_message
289            .entry(assistant_message_id)
290            .or_default();
291
292        let mut existing_tool_use_found = false;
293
294        for existing_tool_use in tool_uses.iter_mut() {
295            if existing_tool_use.id == tool_use.id {
296                *existing_tool_use = tool_use.clone();
297                existing_tool_use_found = true;
298            }
299        }
300
301        if !existing_tool_use_found {
302            tool_uses.push(tool_use.clone());
303        }
304
305        let status = if tool_use.is_input_complete {
306            self.tool_use_metadata_by_id
307                .insert(tool_use.id.clone(), metadata);
308
309            PendingToolUseStatus::Idle
310        } else {
311            PendingToolUseStatus::InputStillStreaming
312        };
313
314        let ui_text: Arc<str> = self
315            .tool_ui_label(
316                &tool_use.name,
317                &tool_use.input,
318                tool_use.is_input_complete,
319                cx,
320            )
321            .into();
322
323        self.pending_tool_uses_by_id.insert(
324            tool_use.id.clone(),
325            PendingToolUse {
326                assistant_message_id,
327                id: tool_use.id,
328                name: tool_use.name.clone(),
329                ui_text: ui_text.clone(),
330                input: tool_use.input,
331                status,
332            },
333        );
334
335        ui_text
336    }
337
338    pub fn run_pending_tool(
339        &mut self,
340        tool_use_id: LanguageModelToolUseId,
341        ui_text: SharedString,
342        task: Task<()>,
343    ) {
344        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
345            tool_use.ui_text = ui_text.into();
346            tool_use.status = PendingToolUseStatus::Running {
347                _task: task.shared(),
348            };
349        }
350    }
351
352    pub fn confirm_tool_use(
353        &mut self,
354        tool_use_id: LanguageModelToolUseId,
355        ui_text: impl Into<Arc<str>>,
356        input: serde_json::Value,
357        request: Arc<LanguageModelRequest>,
358        tool: Arc<dyn Tool>,
359    ) {
360        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
361            let ui_text = ui_text.into();
362            tool_use.ui_text = ui_text.clone();
363            let confirmation = Confirmation {
364                tool_use_id,
365                input,
366                request,
367                tool,
368                ui_text,
369            };
370            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
371        }
372    }
373
374    pub fn insert_tool_output(
375        &mut self,
376        tool_use_id: LanguageModelToolUseId,
377        tool_name: Arc<str>,
378        output: Result<ToolResultOutput>,
379        configured_model: Option<&ConfiguredModel>,
380    ) -> Option<PendingToolUse> {
381        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
382
383        telemetry::event!(
384            "Agent Tool Finished",
385            model = metadata
386                .as_ref()
387                .map(|metadata| metadata.model.telemetry_id()),
388            model_provider = metadata
389                .as_ref()
390                .map(|metadata| metadata.model.provider_id().to_string()),
391            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
392            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
393            tool_name,
394            success = output.is_ok()
395        );
396
397        match output {
398            Ok(output) => {
399                let tool_result = output.content;
400                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
401
402                // Protect from clearly large output
403                let tool_output_limit = configured_model
404                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
405                    .unwrap_or(usize::MAX);
406
407                let tool_result = if tool_result.len() <= tool_output_limit {
408                    tool_result
409                } else {
410                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
411
412                    format!(
413                        "Tool result too long. The first {} bytes:\n\n{}",
414                        truncated.len(),
415                        truncated
416                    )
417                };
418
419                self.tool_results.insert(
420                    tool_use_id.clone(),
421                    LanguageModelToolResult {
422                        tool_use_id: tool_use_id.clone(),
423                        tool_name,
424                        content: tool_result.into(),
425                        is_error: false,
426                        output: output.output,
427                    },
428                );
429                self.pending_tool_uses_by_id.remove(&tool_use_id)
430            }
431            Err(err) => {
432                self.tool_results.insert(
433                    tool_use_id.clone(),
434                    LanguageModelToolResult {
435                        tool_use_id: tool_use_id.clone(),
436                        tool_name,
437                        content: err.to_string().into(),
438                        is_error: true,
439                        output: None,
440                    },
441                );
442
443                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
444                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
445                }
446
447                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
448            }
449        }
450    }
451
452    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
453        self.tool_uses_by_assistant_message
454            .contains_key(&assistant_message_id)
455    }
456
457    pub fn tool_results(
458        &self,
459        assistant_message_id: MessageId,
460    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
461        self.tool_uses_by_assistant_message
462            .get(&assistant_message_id)
463            .into_iter()
464            .flatten()
465            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
466    }
467}
468
/// A tool use that has been requested but has not completed successfully
/// (failed tool uses remain pending in the `Error` state).
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read anywhere in this file; the allow presumably silences a
    // dead-code warning — confirm whether any caller reads this field.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Display label; updated when the tool use is confirmed or starts running.
    pub ui_text: Arc<str>,
    pub input: serde_json::Value,
    pub status: PendingToolUseStatus,
}
480
/// Data captured when a pending tool use is moved into the
/// [`PendingToolUseStatus::NeedsConfirmation`] state.
// NOTE(review): presumably used to run the tool once the user confirms — confirm
// against callers.
#[derive(Debug, Clone)]
pub struct Confirmation {
    pub tool_use_id: LanguageModelToolUseId,
    pub input: serde_json::Value,
    pub ui_text: Arc<str>,
    pub request: Arc<LanguageModelRequest>,
    pub tool: Arc<dyn Tool>,
}
489
490#[derive(Debug, Clone)]
491pub enum PendingToolUseStatus {
492    InputStillStreaming,
493    Idle,
494    NeedsConfirmation(Arc<Confirmation>),
495    Running { _task: Shared<Task<()>> },
496    Error(#[allow(unused)] Arc<str>),
497}
498
499impl PendingToolUseStatus {
500    pub fn is_idle(&self) -> bool {
501        matches!(self, PendingToolUseStatus::Idle)
502    }
503
504    pub fn is_error(&self) -> bool {
505        matches!(self, PendingToolUseStatus::Error(_))
506    }
507
508    pub fn needs_confirmation(&self) -> bool {
509        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
510    }
511}
512
/// Telemetry context for a tool use, reported in the "Agent Tool Finished" event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model that requested the tool use; supplies the telemetry and provider IDs.
    pub model: Arc<dyn LanguageModel>,
    pub thread_id: ThreadId,
    pub prompt_id: PromptId,
}