tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::FutureExt as _;
  8use gpui::{App, SharedString, Task};
  9use language_model::{
 10    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 11    LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14
 15use crate::thread::MessageId;
 16use crate::thread_store::SerializedMessage;
 17
 18#[derive(Debug)]
 19pub struct ToolUse {
 20    pub id: LanguageModelToolUseId,
 21    pub name: SharedString,
 22    pub ui_text: SharedString,
 23    pub status: ToolUseStatus,
 24    pub input: serde_json::Value,
 25    pub icon: ui::IconName,
 26}
 27
 28#[derive(Debug, Clone)]
 29pub enum ToolUseStatus {
 30    NeedsConfirmation,
 31    Pending,
 32    Running,
 33    Finished(SharedString),
 34    Error(SharedString),
 35}
 36
 37pub struct ToolUseState {
 38    tools: Arc<ToolWorkingSet>,
 39    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 40    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 41    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 42    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 43}
 44
 45impl ToolUseState {
 46    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
 47        Self {
 48            tools,
 49            tool_uses_by_assistant_message: HashMap::default(),
 50            tool_uses_by_user_message: HashMap::default(),
 51            tool_results: HashMap::default(),
 52            pending_tool_uses_by_id: HashMap::default(),
 53        }
 54    }
 55
 56    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 57    ///
 58    /// Accepts a function to filter the tools that should be used to populate the state.
 59    pub fn from_serialized_messages(
 60        tools: Arc<ToolWorkingSet>,
 61        messages: &[SerializedMessage],
 62        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 63    ) -> Self {
 64        let mut this = Self::new(tools);
 65        let mut tool_names_by_id = HashMap::default();
 66
 67        for message in messages {
 68            match message.role {
 69                Role::Assistant => {
 70                    if !message.tool_uses.is_empty() {
 71                        let tool_uses = message
 72                            .tool_uses
 73                            .iter()
 74                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 75                            .map(|tool_use| LanguageModelToolUse {
 76                                id: tool_use.id.clone(),
 77                                name: tool_use.name.clone().into(),
 78                                input: tool_use.input.clone(),
 79                            })
 80                            .collect::<Vec<_>>();
 81
 82                        tool_names_by_id.extend(
 83                            tool_uses
 84                                .iter()
 85                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 86                        );
 87
 88                        this.tool_uses_by_assistant_message
 89                            .insert(message.id, tool_uses);
 90                    }
 91                }
 92                Role::User => {
 93                    if !message.tool_results.is_empty() {
 94                        let tool_uses_by_user_message = this
 95                            .tool_uses_by_user_message
 96                            .entry(message.id)
 97                            .or_default();
 98
 99                        for tool_result in &message.tool_results {
100                            let tool_use_id = tool_result.tool_use_id.clone();
101                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
102                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
103                                continue;
104                            };
105
106                            if !(filter_by_tool_name)(tool_use.as_ref()) {
107                                continue;
108                            }
109
110                            tool_uses_by_user_message.push(tool_use_id.clone());
111                            this.tool_results.insert(
112                                tool_use_id.clone(),
113                                LanguageModelToolResult {
114                                    tool_use_id,
115                                    is_error: tool_result.is_error,
116                                    content: tool_result.content.clone(),
117                                },
118                            );
119                        }
120                    }
121                }
122                Role::System => {}
123            }
124        }
125
126        this
127    }
128
129    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130        let mut pending_tools = Vec::new();
131        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
132            self.tool_results.insert(
133                tool_use_id.clone(),
134                LanguageModelToolResult {
135                    tool_use_id,
136                    content: "Tool canceled by user".into(),
137                    is_error: true,
138                },
139            );
140            pending_tools.push(tool_use.clone());
141        }
142        pending_tools
143    }
144
145    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
146        self.pending_tool_uses_by_id.values().collect()
147    }
148
149    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
150        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
151            return Vec::new();
152        };
153
154        let mut tool_uses = Vec::new();
155
156        for tool_use in tool_uses_for_message.iter() {
157            let tool_result = self.tool_results.get(&tool_use.id);
158
159            let status = (|| {
160                if let Some(tool_result) = tool_result {
161                    return if tool_result.is_error {
162                        ToolUseStatus::Error(tool_result.content.clone().into())
163                    } else {
164                        ToolUseStatus::Finished(tool_result.content.clone().into())
165                    };
166                }
167
168                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
169                    match pending_tool_use.status {
170                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
171                        PendingToolUseStatus::NeedsConfirmation { .. } => {
172                            ToolUseStatus::NeedsConfirmation
173                        }
174                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
175                        PendingToolUseStatus::Error(ref err) => {
176                            ToolUseStatus::Error(err.clone().into())
177                        }
178                    }
179                } else {
180                    ToolUseStatus::Pending
181                }
182            })();
183
184            let icon = if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
185                tool.icon()
186            } else {
187                IconName::Cog
188            };
189
190            tool_uses.push(ToolUse {
191                id: tool_use.id.clone(),
192                name: tool_use.name.clone().into(),
193                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
194                input: tool_use.input.clone(),
195                status,
196                icon,
197            })
198        }
199
200        tool_uses
201    }
202
203    pub fn tool_ui_label(
204        &self,
205        tool_name: &str,
206        input: &serde_json::Value,
207        cx: &App,
208    ) -> SharedString {
209        if let Some(tool) = self.tools.tool(tool_name, cx) {
210            tool.ui_text(input).into()
211        } else {
212            "Unknown tool".into()
213        }
214    }
215
216    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
217        let empty = Vec::new();
218
219        self.tool_uses_by_user_message
220            .get(&message_id)
221            .unwrap_or(&empty)
222            .iter()
223            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
224            .collect()
225    }
226
227    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
228        self.tool_uses_by_user_message
229            .get(&message_id)
230            .map_or(false, |results| !results.is_empty())
231    }
232
233    pub fn tool_result(
234        &self,
235        tool_use_id: &LanguageModelToolUseId,
236    ) -> Option<&LanguageModelToolResult> {
237        self.tool_results.get(tool_use_id)
238    }
239
240    pub fn request_tool_use(
241        &mut self,
242        assistant_message_id: MessageId,
243        tool_use: LanguageModelToolUse,
244        cx: &App,
245    ) {
246        self.tool_uses_by_assistant_message
247            .entry(assistant_message_id)
248            .or_default()
249            .push(tool_use.clone());
250
251        // The tool use is being requested by the Assistant, so we want to
252        // attach the tool results to the next user message.
253        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
254        self.tool_uses_by_user_message
255            .entry(next_user_message_id)
256            .or_default()
257            .push(tool_use.id.clone());
258
259        self.pending_tool_uses_by_id.insert(
260            tool_use.id.clone(),
261            PendingToolUse {
262                assistant_message_id,
263                id: tool_use.id,
264                name: tool_use.name.clone(),
265                ui_text: self
266                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
267                    .into(),
268                input: tool_use.input,
269                status: PendingToolUseStatus::Idle,
270            },
271        );
272    }
273
274    pub fn run_pending_tool(
275        &mut self,
276        tool_use_id: LanguageModelToolUseId,
277        ui_text: SharedString,
278        task: Task<()>,
279    ) {
280        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
281            tool_use.ui_text = ui_text.into();
282            tool_use.status = PendingToolUseStatus::Running {
283                _task: task.shared(),
284            };
285        }
286    }
287
288    pub fn confirm_tool_use(
289        &mut self,
290        tool_use_id: LanguageModelToolUseId,
291        ui_text: impl Into<Arc<str>>,
292        input: serde_json::Value,
293        messages: Arc<Vec<LanguageModelRequestMessage>>,
294        tool: Arc<dyn Tool>,
295    ) {
296        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
297            let ui_text = ui_text.into();
298            tool_use.ui_text = ui_text.clone();
299            let confirmation = Confirmation {
300                tool_use_id,
301                input,
302                messages,
303                tool,
304                ui_text,
305            };
306            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
307        }
308    }
309
310    pub fn insert_tool_output(
311        &mut self,
312        tool_use_id: LanguageModelToolUseId,
313        output: Result<String>,
314    ) -> Option<PendingToolUse> {
315        match output {
316            Ok(tool_result) => {
317                self.tool_results.insert(
318                    tool_use_id.clone(),
319                    LanguageModelToolResult {
320                        tool_use_id: tool_use_id.clone(),
321                        content: tool_result.into(),
322                        is_error: false,
323                    },
324                );
325                self.pending_tool_uses_by_id.remove(&tool_use_id)
326            }
327            Err(err) => {
328                self.tool_results.insert(
329                    tool_use_id.clone(),
330                    LanguageModelToolResult {
331                        tool_use_id: tool_use_id.clone(),
332                        content: err.to_string().into(),
333                        is_error: true,
334                    },
335                );
336
337                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
338                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
339                }
340
341                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
342            }
343        }
344    }
345
346    pub fn attach_tool_uses(
347        &self,
348        message_id: MessageId,
349        request_message: &mut LanguageModelRequestMessage,
350    ) {
351        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
352            for tool_use in tool_uses {
353                if self.tool_results.contains_key(&tool_use.id) {
354                    // Do not send tool uses until they are completed
355                    request_message
356                        .content
357                        .push(MessageContent::ToolUse(tool_use.clone()));
358                } else {
359                    log::debug!(
360                        "skipped tool use {:?} because it is still pending",
361                        tool_use
362                    );
363                }
364            }
365        }
366    }
367
368    pub fn attach_tool_results(
369        &self,
370        message_id: MessageId,
371        request_message: &mut LanguageModelRequestMessage,
372    ) {
373        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
374            for tool_use_id in tool_uses {
375                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
376                    request_message.content.push(MessageContent::ToolResult(
377                        LanguageModelToolResult {
378                            tool_use_id: tool_use_id.clone(),
379                            is_error: tool_result.is_error,
380                            content: if tool_result.content.is_empty() {
381                                // Surprisingly, the API fails if we return an empty string here.
382                                // It thinks we are sending a tool use without a tool result.
383                                "<Tool returned an empty string>".into()
384                            } else {
385                                tool_result.content.clone()
386                            },
387                        },
388                    ));
389                }
390            }
391        }
392    }
393}
394
395#[derive(Debug, Clone)]
396pub struct PendingToolUse {
397    pub id: LanguageModelToolUseId,
398    /// The ID of the Assistant message in which the tool use was requested.
399    #[allow(unused)]
400    pub assistant_message_id: MessageId,
401    pub name: Arc<str>,
402    pub ui_text: Arc<str>,
403    pub input: serde_json::Value,
404    pub status: PendingToolUseStatus,
405}
406
407#[derive(Debug, Clone)]
408pub struct Confirmation {
409    pub tool_use_id: LanguageModelToolUseId,
410    pub input: serde_json::Value,
411    pub ui_text: Arc<str>,
412    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
413    pub tool: Arc<dyn Tool>,
414}
415
416#[derive(Debug, Clone)]
417pub enum PendingToolUseStatus {
418    Idle,
419    NeedsConfirmation(Arc<Confirmation>),
420    Running { _task: Shared<Task<()>> },
421    Error(#[allow(unused)] Arc<str>),
422}
423
424impl PendingToolUseStatus {
425    pub fn is_idle(&self) -> bool {
426        matches!(self, PendingToolUseStatus::Idle)
427    }
428
429    pub fn is_error(&self) -> bool {
430        matches!(self, PendingToolUseStatus::Error(_))
431    }
432
433    pub fn needs_confirmation(&self) -> bool {
434        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
435    }
436}