tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::FutureExt as _;
  8use gpui::{App, SharedString, Task};
  9use language_model::{
 10    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 11    LanguageModelToolUseId, MessageContent, Role,
 12};
 13
 14use crate::thread::MessageId;
 15use crate::thread_store::SerializedMessage;
 16
 17#[derive(Debug)]
 18pub struct ToolUse {
 19    pub id: LanguageModelToolUseId,
 20    pub name: SharedString,
 21    pub ui_text: SharedString,
 22    pub status: ToolUseStatus,
 23    pub input: serde_json::Value,
 24}
 25
 26#[derive(Debug, Clone)]
 27pub enum ToolUseStatus {
 28    Pending,
 29    Running,
 30    Finished(SharedString),
 31    Error(SharedString),
 32}
 33
 34pub struct ToolUseState {
 35    tools: Arc<ToolWorkingSet>,
 36    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 37    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 38    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 39    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 40}
 41
 42impl ToolUseState {
 43    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
 44        Self {
 45            tools,
 46            tool_uses_by_assistant_message: HashMap::default(),
 47            tool_uses_by_user_message: HashMap::default(),
 48            tool_results: HashMap::default(),
 49            pending_tool_uses_by_id: HashMap::default(),
 50        }
 51    }
 52
 53    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 54    ///
 55    /// Accepts a function to filter the tools that should be used to populate the state.
 56    pub fn from_serialized_messages(
 57        tools: Arc<ToolWorkingSet>,
 58        messages: &[SerializedMessage],
 59        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 60    ) -> Self {
 61        let mut this = Self::new(tools);
 62        let mut tool_names_by_id = HashMap::default();
 63
 64        for message in messages {
 65            match message.role {
 66                Role::Assistant => {
 67                    if !message.tool_uses.is_empty() {
 68                        let tool_uses = message
 69                            .tool_uses
 70                            .iter()
 71                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 72                            .map(|tool_use| LanguageModelToolUse {
 73                                id: tool_use.id.clone(),
 74                                name: tool_use.name.clone().into(),
 75                                input: tool_use.input.clone(),
 76                            })
 77                            .collect::<Vec<_>>();
 78
 79                        tool_names_by_id.extend(
 80                            tool_uses
 81                                .iter()
 82                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 83                        );
 84
 85                        this.tool_uses_by_assistant_message
 86                            .insert(message.id, tool_uses);
 87                    }
 88                }
 89                Role::User => {
 90                    if !message.tool_results.is_empty() {
 91                        let tool_uses_by_user_message = this
 92                            .tool_uses_by_user_message
 93                            .entry(message.id)
 94                            .or_default();
 95
 96                        for tool_result in &message.tool_results {
 97                            let tool_use_id = tool_result.tool_use_id.clone();
 98                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 99                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
100                                continue;
101                            };
102
103                            if !(filter_by_tool_name)(tool_use.as_ref()) {
104                                continue;
105                            }
106
107                            tool_uses_by_user_message.push(tool_use_id.clone());
108                            this.tool_results.insert(
109                                tool_use_id.clone(),
110                                LanguageModelToolResult {
111                                    tool_use_id,
112                                    is_error: tool_result.is_error,
113                                    content: tool_result.content.clone(),
114                                },
115                            );
116                        }
117                    }
118                }
119                Role::System => {}
120            }
121        }
122
123        this
124    }
125
126    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
127        let mut pending_tools = Vec::new();
128        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
129            self.tool_results.insert(
130                tool_use_id.clone(),
131                LanguageModelToolResult {
132                    tool_use_id,
133                    content: "Tool canceled by user".into(),
134                    is_error: true,
135                },
136            );
137            pending_tools.push(tool_use.clone());
138        }
139        pending_tools
140    }
141
142    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
143        self.pending_tool_uses_by_id.values().collect()
144    }
145
146    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
147        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
148            return Vec::new();
149        };
150
151        let mut tool_uses = Vec::new();
152
153        for tool_use in tool_uses_for_message.iter() {
154            let tool_result = self.tool_results.get(&tool_use.id);
155
156            let status = (|| {
157                if let Some(tool_result) = tool_result {
158                    return if tool_result.is_error {
159                        ToolUseStatus::Error(tool_result.content.clone().into())
160                    } else {
161                        ToolUseStatus::Finished(tool_result.content.clone().into())
162                    };
163                }
164
165                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
166                    return match pending_tool_use.status {
167                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
168                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
169                        PendingToolUseStatus::Error(ref err) => {
170                            ToolUseStatus::Error(err.clone().into())
171                        }
172                    };
173                }
174
175                ToolUseStatus::Pending
176            })();
177
178            tool_uses.push(ToolUse {
179                id: tool_use.id.clone(),
180                name: tool_use.name.clone().into(),
181                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
182                input: tool_use.input.clone(),
183                status,
184            })
185        }
186
187        tool_uses
188    }
189
190    pub fn tool_ui_label(
191        &self,
192        tool_name: &str,
193        input: &serde_json::Value,
194        cx: &App,
195    ) -> SharedString {
196        if let Some(tool) = self.tools.tool(tool_name, cx) {
197            tool.ui_text(input).into()
198        } else {
199            "Unknown tool".into()
200        }
201    }
202
203    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
204        let empty = Vec::new();
205
206        self.tool_uses_by_user_message
207            .get(&message_id)
208            .unwrap_or(&empty)
209            .iter()
210            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
211            .collect()
212    }
213
214    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
215        self.tool_uses_by_user_message
216            .get(&message_id)
217            .map_or(false, |results| !results.is_empty())
218    }
219
220    pub fn tool_result(
221        &self,
222        tool_use_id: &LanguageModelToolUseId,
223    ) -> Option<&LanguageModelToolResult> {
224        self.tool_results.get(tool_use_id)
225    }
226
227    pub fn request_tool_use(
228        &mut self,
229        assistant_message_id: MessageId,
230        tool_use: LanguageModelToolUse,
231        cx: &App,
232    ) {
233        self.tool_uses_by_assistant_message
234            .entry(assistant_message_id)
235            .or_default()
236            .push(tool_use.clone());
237
238        // The tool use is being requested by the Assistant, so we want to
239        // attach the tool results to the next user message.
240        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
241        self.tool_uses_by_user_message
242            .entry(next_user_message_id)
243            .or_default()
244            .push(tool_use.id.clone());
245
246        self.pending_tool_uses_by_id.insert(
247            tool_use.id.clone(),
248            PendingToolUse {
249                assistant_message_id,
250                id: tool_use.id,
251                name: tool_use.name.clone(),
252                ui_text: self
253                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
254                    .into(),
255                input: tool_use.input,
256                status: PendingToolUseStatus::Idle,
257            },
258        );
259    }
260
261    pub fn run_pending_tool(
262        &mut self,
263        tool_use_id: LanguageModelToolUseId,
264        ui_text: SharedString,
265        task: Task<()>,
266    ) {
267        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
268            tool_use.ui_text = ui_text.into();
269            tool_use.status = PendingToolUseStatus::Running {
270                _task: task.shared(),
271            };
272        }
273    }
274
275    pub fn insert_tool_output(
276        &mut self,
277        tool_use_id: LanguageModelToolUseId,
278        output: Result<String>,
279    ) -> Option<PendingToolUse> {
280        match output {
281            Ok(tool_result) => {
282                self.tool_results.insert(
283                    tool_use_id.clone(),
284                    LanguageModelToolResult {
285                        tool_use_id: tool_use_id.clone(),
286                        content: tool_result.into(),
287                        is_error: false,
288                    },
289                );
290                self.pending_tool_uses_by_id.remove(&tool_use_id)
291            }
292            Err(err) => {
293                self.tool_results.insert(
294                    tool_use_id.clone(),
295                    LanguageModelToolResult {
296                        tool_use_id: tool_use_id.clone(),
297                        content: err.to_string().into(),
298                        is_error: true,
299                    },
300                );
301
302                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
303                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
304                }
305
306                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
307            }
308        }
309    }
310
311    pub fn attach_tool_uses(
312        &self,
313        message_id: MessageId,
314        request_message: &mut LanguageModelRequestMessage,
315    ) {
316        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
317            for tool_use in tool_uses {
318                if self.tool_results.contains_key(&tool_use.id) {
319                    // Do not send tool uses until they are completed
320                    request_message
321                        .content
322                        .push(MessageContent::ToolUse(tool_use.clone()));
323                } else {
324                    log::debug!(
325                        "skipped tool use {:?} because it is still pending",
326                        tool_use
327                    );
328                }
329            }
330        }
331    }
332
333    pub fn attach_tool_results(
334        &self,
335        message_id: MessageId,
336        request_message: &mut LanguageModelRequestMessage,
337    ) {
338        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
339            for tool_use_id in tool_uses {
340                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
341                    request_message.content.push(MessageContent::ToolResult(
342                        LanguageModelToolResult {
343                            tool_use_id: tool_use_id.clone(),
344                            is_error: tool_result.is_error,
345                            content: if tool_result.content.is_empty() {
346                                // Surprisingly, the API fails if we return an empty string here.
347                                // It thinks we are sending a tool use without a tool result.
348                                "<Tool returned an empty string>".into()
349                            } else {
350                                tool_result.content.clone()
351                            },
352                        },
353                    ));
354                }
355            }
356        }
357    }
358}
359
360#[derive(Debug, Clone)]
361pub struct PendingToolUse {
362    pub id: LanguageModelToolUseId,
363    /// The ID of the Assistant message in which the tool use was requested.
364    #[allow(unused)]
365    pub assistant_message_id: MessageId,
366    pub name: Arc<str>,
367    pub ui_text: Arc<str>,
368    pub input: serde_json::Value,
369    pub status: PendingToolUseStatus,
370}
371
372#[derive(Debug, Clone)]
373pub enum PendingToolUseStatus {
374    Idle,
375    Running { _task: Shared<Task<()>> },
376    Error(#[allow(unused)] Arc<str>),
377}
378
379impl PendingToolUseStatus {
380    pub fn is_idle(&self) -> bool {
381        matches!(self, PendingToolUseStatus::Idle)
382    }
383
384    pub fn is_error(&self) -> bool {
385        matches!(self, PendingToolUseStatus::Error(_))
386    }
387}