tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashMap;
  5use futures::future::Shared;
  6use futures::FutureExt as _;
  7use gpui::{SharedString, Task};
  8use language_model::{
  9    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 10    LanguageModelToolUseId, MessageContent, Role,
 11};
 12
 13use crate::thread::MessageId;
 14use crate::thread_store::SerializedMessage;
 15
 16#[derive(Debug)]
 17pub struct ToolUse {
 18    pub id: LanguageModelToolUseId,
 19    pub name: SharedString,
 20    pub status: ToolUseStatus,
 21    pub input: serde_json::Value,
 22}
 23
 24#[derive(Debug, Clone)]
 25pub enum ToolUseStatus {
 26    Pending,
 27    Running,
 28    Finished(SharedString),
 29    Error(SharedString),
 30}
 31
 32pub struct ToolUseState {
 33    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 34    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 35    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 36    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 37}
 38
 39impl ToolUseState {
 40    pub fn new() -> Self {
 41        Self {
 42            tool_uses_by_assistant_message: HashMap::default(),
 43            tool_uses_by_user_message: HashMap::default(),
 44            tool_results: HashMap::default(),
 45            pending_tool_uses_by_id: HashMap::default(),
 46        }
 47    }
 48
 49    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 50    ///
 51    /// Accepts a function to filter the tools that should be used to populate the state.
 52    pub fn from_serialized_messages(
 53        messages: &[SerializedMessage],
 54        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 55    ) -> Self {
 56        let mut this = Self::new();
 57        let mut tool_names_by_id = HashMap::default();
 58
 59        for message in messages {
 60            match message.role {
 61                Role::Assistant => {
 62                    if !message.tool_uses.is_empty() {
 63                        let tool_uses = message
 64                            .tool_uses
 65                            .iter()
 66                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 67                            .map(|tool_use| LanguageModelToolUse {
 68                                id: tool_use.id.clone(),
 69                                name: tool_use.name.clone().into(),
 70                                input: tool_use.input.clone(),
 71                            })
 72                            .collect::<Vec<_>>();
 73
 74                        tool_names_by_id.extend(
 75                            tool_uses
 76                                .iter()
 77                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 78                        );
 79
 80                        this.tool_uses_by_assistant_message
 81                            .insert(message.id, tool_uses);
 82                    }
 83                }
 84                Role::User => {
 85                    if !message.tool_results.is_empty() {
 86                        let tool_uses_by_user_message = this
 87                            .tool_uses_by_user_message
 88                            .entry(message.id)
 89                            .or_default();
 90
 91                        for tool_result in &message.tool_results {
 92                            let tool_use_id = tool_result.tool_use_id.clone();
 93                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 94                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
 95                                continue;
 96                            };
 97
 98                            if !(filter_by_tool_name)(tool_use.as_ref()) {
 99                                continue;
100                            }
101
102                            tool_uses_by_user_message.push(tool_use_id.clone());
103                            this.tool_results.insert(
104                                tool_use_id.clone(),
105                                LanguageModelToolResult {
106                                    tool_use_id,
107                                    is_error: tool_result.is_error,
108                                    content: tool_result.content.clone(),
109                                },
110                            );
111                        }
112                    }
113                }
114                Role::System => {}
115            }
116        }
117
118        this
119    }
120
121    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
122        let mut pending_tools = Vec::new();
123        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
124            self.tool_results.insert(
125                tool_use_id.clone(),
126                LanguageModelToolResult {
127                    tool_use_id,
128                    content: "Tool canceled by user".into(),
129                    is_error: true,
130                },
131            );
132            pending_tools.push(tool_use.clone());
133        }
134        pending_tools
135    }
136
137    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
138        self.pending_tool_uses_by_id.values().collect()
139    }
140
141    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
142        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
143            return Vec::new();
144        };
145
146        let mut tool_uses = Vec::new();
147
148        for tool_use in tool_uses_for_message.iter() {
149            let tool_result = self.tool_results.get(&tool_use.id);
150
151            let status = (|| {
152                if let Some(tool_result) = tool_result {
153                    return if tool_result.is_error {
154                        ToolUseStatus::Error(tool_result.content.clone().into())
155                    } else {
156                        ToolUseStatus::Finished(tool_result.content.clone().into())
157                    };
158                }
159
160                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
161                    return match pending_tool_use.status {
162                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
163                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
164                        PendingToolUseStatus::Error(ref err) => {
165                            ToolUseStatus::Error(err.clone().into())
166                        }
167                    };
168                }
169
170                ToolUseStatus::Pending
171            })();
172
173            tool_uses.push(ToolUse {
174                id: tool_use.id.clone(),
175                name: tool_use.name.clone().into(),
176                input: tool_use.input.clone(),
177                status,
178            })
179        }
180
181        tool_uses
182    }
183
184    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
185        let empty = Vec::new();
186
187        self.tool_uses_by_user_message
188            .get(&message_id)
189            .unwrap_or(&empty)
190            .iter()
191            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
192            .collect()
193    }
194
195    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
196        self.tool_uses_by_user_message
197            .get(&message_id)
198            .map_or(false, |results| !results.is_empty())
199    }
200
201    pub fn tool_result(
202        &self,
203        tool_use_id: &LanguageModelToolUseId,
204    ) -> Option<&LanguageModelToolResult> {
205        self.tool_results.get(tool_use_id)
206    }
207
208    pub fn request_tool_use(
209        &mut self,
210        assistant_message_id: MessageId,
211        tool_use: LanguageModelToolUse,
212    ) {
213        self.tool_uses_by_assistant_message
214            .entry(assistant_message_id)
215            .or_default()
216            .push(tool_use.clone());
217
218        // The tool use is being requested by the Assistant, so we want to
219        // attach the tool results to the next user message.
220        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
221        self.tool_uses_by_user_message
222            .entry(next_user_message_id)
223            .or_default()
224            .push(tool_use.id.clone());
225
226        self.pending_tool_uses_by_id.insert(
227            tool_use.id.clone(),
228            PendingToolUse {
229                assistant_message_id,
230                id: tool_use.id,
231                name: tool_use.name,
232                input: tool_use.input,
233                status: PendingToolUseStatus::Idle,
234            },
235        );
236    }
237
238    pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
239        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
240            tool_use.status = PendingToolUseStatus::Running {
241                _task: task.shared(),
242            };
243        }
244    }
245
246    pub fn insert_tool_output(
247        &mut self,
248        tool_use_id: LanguageModelToolUseId,
249        output: Result<String>,
250    ) -> Option<PendingToolUse> {
251        match output {
252            Ok(tool_result) => {
253                self.tool_results.insert(
254                    tool_use_id.clone(),
255                    LanguageModelToolResult {
256                        tool_use_id: tool_use_id.clone(),
257                        content: tool_result.into(),
258                        is_error: false,
259                    },
260                );
261                self.pending_tool_uses_by_id.remove(&tool_use_id)
262            }
263            Err(err) => {
264                self.tool_results.insert(
265                    tool_use_id.clone(),
266                    LanguageModelToolResult {
267                        tool_use_id: tool_use_id.clone(),
268                        content: err.to_string().into(),
269                        is_error: true,
270                    },
271                );
272
273                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
274                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
275                }
276
277                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
278            }
279        }
280    }
281
282    pub fn attach_tool_uses(
283        &self,
284        message_id: MessageId,
285        request_message: &mut LanguageModelRequestMessage,
286    ) {
287        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
288            for tool_use in tool_uses {
289                if self.tool_results.contains_key(&tool_use.id) {
290                    // Do not send tool uses until they are completed
291                    request_message
292                        .content
293                        .push(MessageContent::ToolUse(tool_use.clone()));
294                } else {
295                    log::debug!(
296                        "skipped tool use {:?} because it is still pending",
297                        tool_use
298                    );
299                }
300            }
301        }
302    }
303
304    pub fn attach_tool_results(
305        &self,
306        message_id: MessageId,
307        request_message: &mut LanguageModelRequestMessage,
308    ) {
309        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
310            for tool_use_id in tool_uses {
311                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
312                    request_message.content.push(MessageContent::ToolResult(
313                        LanguageModelToolResult {
314                            tool_use_id: tool_use_id.clone(),
315                            is_error: tool_result.is_error,
316                            content: if tool_result.content.is_empty() {
317                                // Surprisingly, the API fails if we return an empty string here.
318                                // It thinks we are sending a tool use without a tool result.
319                                "<Tool returned an empty string>".into()
320                            } else {
321                                tool_result.content.clone()
322                            },
323                        },
324                    ));
325                }
326            }
327        }
328    }
329}
330
331#[derive(Debug, Clone)]
332pub struct PendingToolUse {
333    pub id: LanguageModelToolUseId,
334    /// The ID of the Assistant message in which the tool use was requested.
335    #[allow(unused)]
336    pub assistant_message_id: MessageId,
337    pub name: Arc<str>,
338    pub input: serde_json::Value,
339    pub status: PendingToolUseStatus,
340}
341
342#[derive(Debug, Clone)]
343pub enum PendingToolUseStatus {
344    Idle,
345    Running { _task: Shared<Task<()>> },
346    Error(#[allow(unused)] Arc<str>),
347}
348
349impl PendingToolUseStatus {
350    pub fn is_idle(&self) -> bool {
351        matches!(self, PendingToolUseStatus::Idle)
352    }
353
354    pub fn is_error(&self) -> bool {
355        matches!(self, PendingToolUseStatus::Error(_))
356    }
357}