tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::FutureExt as _;
  8use gpui::{App, SharedString, Task};
  9use language_model::{
 10    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 11    LanguageModelToolUseId, MessageContent, Role,
 12};
 13
 14use crate::thread::MessageId;
 15use crate::thread_store::SerializedMessage;
 16
/// A tool use requested in an assistant message, paired with its current status,
/// as returned by [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// ID of the tool use, as supplied with the assistant's request.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: SharedString,
    /// Human-readable label produced by `ToolUseState::tool_ui_label`
    /// ("Unknown tool" when the tool is not in the working set).
    pub ui_text: SharedString,
    /// Current status of the tool use.
    pub status: ToolUseStatus,
    /// JSON input supplied for the tool.
    pub input: serde_json::Value,
}
 25
/// Display status of a [`ToolUse`].
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The tool use is waiting for confirmation before it can run.
    NeedsConfirmation,
    /// The tool use has not started yet (also used when no pending state is tracked).
    Pending,
    /// The tool is currently running.
    Running,
    /// The tool completed successfully; holds its output.
    Finished(SharedString),
    /// The tool use failed or was canceled; holds the error text.
    Error(SharedString),
}
 34
/// Tracks the tool uses requested in a thread, the results recorded for them,
/// and the tool uses that are still pending.
pub struct ToolUseState {
    /// Working set used to look up tools by name (e.g. to build UI labels).
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested in each assistant message, in request order.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results belong to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful, failed, or canceled), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have neither succeeded nor been canceled.
    /// Failed tool uses stay here with an error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
 42
 43impl ToolUseState {
 44    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
 45        Self {
 46            tools,
 47            tool_uses_by_assistant_message: HashMap::default(),
 48            tool_uses_by_user_message: HashMap::default(),
 49            tool_results: HashMap::default(),
 50            pending_tool_uses_by_id: HashMap::default(),
 51        }
 52    }
 53
 54    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 55    ///
 56    /// Accepts a function to filter the tools that should be used to populate the state.
 57    pub fn from_serialized_messages(
 58        tools: Arc<ToolWorkingSet>,
 59        messages: &[SerializedMessage],
 60        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 61    ) -> Self {
 62        let mut this = Self::new(tools);
 63        let mut tool_names_by_id = HashMap::default();
 64
 65        for message in messages {
 66            match message.role {
 67                Role::Assistant => {
 68                    if !message.tool_uses.is_empty() {
 69                        let tool_uses = message
 70                            .tool_uses
 71                            .iter()
 72                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 73                            .map(|tool_use| LanguageModelToolUse {
 74                                id: tool_use.id.clone(),
 75                                name: tool_use.name.clone().into(),
 76                                input: tool_use.input.clone(),
 77                            })
 78                            .collect::<Vec<_>>();
 79
 80                        tool_names_by_id.extend(
 81                            tool_uses
 82                                .iter()
 83                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 84                        );
 85
 86                        this.tool_uses_by_assistant_message
 87                            .insert(message.id, tool_uses);
 88                    }
 89                }
 90                Role::User => {
 91                    if !message.tool_results.is_empty() {
 92                        let tool_uses_by_user_message = this
 93                            .tool_uses_by_user_message
 94                            .entry(message.id)
 95                            .or_default();
 96
 97                        for tool_result in &message.tool_results {
 98                            let tool_use_id = tool_result.tool_use_id.clone();
 99                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
101                                continue;
102                            };
103
104                            if !(filter_by_tool_name)(tool_use.as_ref()) {
105                                continue;
106                            }
107
108                            tool_uses_by_user_message.push(tool_use_id.clone());
109                            this.tool_results.insert(
110                                tool_use_id.clone(),
111                                LanguageModelToolResult {
112                                    tool_use_id,
113                                    is_error: tool_result.is_error,
114                                    content: tool_result.content.clone(),
115                                },
116                            );
117                        }
118                    }
119                }
120                Role::System => {}
121            }
122        }
123
124        this
125    }
126
127    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
128        let mut pending_tools = Vec::new();
129        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
130            self.tool_results.insert(
131                tool_use_id.clone(),
132                LanguageModelToolResult {
133                    tool_use_id,
134                    content: "Tool canceled by user".into(),
135                    is_error: true,
136                },
137            );
138            pending_tools.push(tool_use.clone());
139        }
140        pending_tools
141    }
142
143    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
144        self.pending_tool_uses_by_id.values().collect()
145    }
146
147    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
148        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
149            return Vec::new();
150        };
151
152        let mut tool_uses = Vec::new();
153
154        for tool_use in tool_uses_for_message.iter() {
155            let tool_result = self.tool_results.get(&tool_use.id);
156
157            let status = (|| {
158                if let Some(tool_result) = tool_result {
159                    return if tool_result.is_error {
160                        ToolUseStatus::Error(tool_result.content.clone().into())
161                    } else {
162                        ToolUseStatus::Finished(tool_result.content.clone().into())
163                    };
164                }
165
166                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
167                    match pending_tool_use.status {
168                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
169                        PendingToolUseStatus::NeedsConfirmation { .. } => {
170                            ToolUseStatus::NeedsConfirmation
171                        }
172                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
173                        PendingToolUseStatus::Error(ref err) => {
174                            ToolUseStatus::Error(err.clone().into())
175                        }
176                    }
177                } else {
178                    ToolUseStatus::Pending
179                }
180            })();
181
182            tool_uses.push(ToolUse {
183                id: tool_use.id.clone(),
184                name: tool_use.name.clone().into(),
185                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
186                input: tool_use.input.clone(),
187                status,
188            })
189        }
190
191        tool_uses
192    }
193
194    pub fn tool_ui_label(
195        &self,
196        tool_name: &str,
197        input: &serde_json::Value,
198        cx: &App,
199    ) -> SharedString {
200        if let Some(tool) = self.tools.tool(tool_name, cx) {
201            tool.ui_text(input).into()
202        } else {
203            "Unknown tool".into()
204        }
205    }
206
207    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
208        let empty = Vec::new();
209
210        self.tool_uses_by_user_message
211            .get(&message_id)
212            .unwrap_or(&empty)
213            .iter()
214            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
215            .collect()
216    }
217
218    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
219        self.tool_uses_by_user_message
220            .get(&message_id)
221            .map_or(false, |results| !results.is_empty())
222    }
223
224    pub fn tool_result(
225        &self,
226        tool_use_id: &LanguageModelToolUseId,
227    ) -> Option<&LanguageModelToolResult> {
228        self.tool_results.get(tool_use_id)
229    }
230
231    pub fn request_tool_use(
232        &mut self,
233        assistant_message_id: MessageId,
234        tool_use: LanguageModelToolUse,
235        cx: &App,
236    ) {
237        self.tool_uses_by_assistant_message
238            .entry(assistant_message_id)
239            .or_default()
240            .push(tool_use.clone());
241
242        // The tool use is being requested by the Assistant, so we want to
243        // attach the tool results to the next user message.
244        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
245        self.tool_uses_by_user_message
246            .entry(next_user_message_id)
247            .or_default()
248            .push(tool_use.id.clone());
249
250        self.pending_tool_uses_by_id.insert(
251            tool_use.id.clone(),
252            PendingToolUse {
253                assistant_message_id,
254                id: tool_use.id,
255                name: tool_use.name.clone(),
256                ui_text: self
257                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
258                    .into(),
259                input: tool_use.input,
260                status: PendingToolUseStatus::Idle,
261            },
262        );
263    }
264
265    pub fn run_pending_tool(
266        &mut self,
267        tool_use_id: LanguageModelToolUseId,
268        ui_text: SharedString,
269        task: Task<()>,
270    ) {
271        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
272            tool_use.ui_text = ui_text.into();
273            tool_use.status = PendingToolUseStatus::Running {
274                _task: task.shared(),
275            };
276        }
277    }
278
279    pub fn confirm_tool_use(
280        &mut self,
281        tool_use_id: LanguageModelToolUseId,
282        ui_text: impl Into<Arc<str>>,
283        input: serde_json::Value,
284        messages: Arc<Vec<LanguageModelRequestMessage>>,
285        tool: Arc<dyn Tool>,
286    ) {
287        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
288            let ui_text = ui_text.into();
289            tool_use.ui_text = ui_text.clone();
290            let confirmation = Confirmation {
291                tool_use_id,
292                input,
293                messages,
294                tool,
295                ui_text,
296            };
297            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
298        }
299    }
300
301    pub fn insert_tool_output(
302        &mut self,
303        tool_use_id: LanguageModelToolUseId,
304        output: Result<String>,
305    ) -> Option<PendingToolUse> {
306        match output {
307            Ok(tool_result) => {
308                self.tool_results.insert(
309                    tool_use_id.clone(),
310                    LanguageModelToolResult {
311                        tool_use_id: tool_use_id.clone(),
312                        content: tool_result.into(),
313                        is_error: false,
314                    },
315                );
316                self.pending_tool_uses_by_id.remove(&tool_use_id)
317            }
318            Err(err) => {
319                self.tool_results.insert(
320                    tool_use_id.clone(),
321                    LanguageModelToolResult {
322                        tool_use_id: tool_use_id.clone(),
323                        content: err.to_string().into(),
324                        is_error: true,
325                    },
326                );
327
328                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
329                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
330                }
331
332                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
333            }
334        }
335    }
336
337    pub fn attach_tool_uses(
338        &self,
339        message_id: MessageId,
340        request_message: &mut LanguageModelRequestMessage,
341    ) {
342        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
343            for tool_use in tool_uses {
344                if self.tool_results.contains_key(&tool_use.id) {
345                    // Do not send tool uses until they are completed
346                    request_message
347                        .content
348                        .push(MessageContent::ToolUse(tool_use.clone()));
349                } else {
350                    log::debug!(
351                        "skipped tool use {:?} because it is still pending",
352                        tool_use
353                    );
354                }
355            }
356        }
357    }
358
359    pub fn attach_tool_results(
360        &self,
361        message_id: MessageId,
362        request_message: &mut LanguageModelRequestMessage,
363    ) {
364        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
365            for tool_use_id in tool_uses {
366                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
367                    request_message.content.push(MessageContent::ToolResult(
368                        LanguageModelToolResult {
369                            tool_use_id: tool_use_id.clone(),
370                            is_error: tool_result.is_error,
371                            content: if tool_result.content.is_empty() {
372                                // Surprisingly, the API fails if we return an empty string here.
373                                // It thinks we are sending a tool use without a tool result.
374                                "<Tool returned an empty string>".into()
375                            } else {
376                                tool_result.content.clone()
377                            },
378                        },
379                    ));
380                }
381            }
382        }
383    }
384}
385
/// A tool use that has been requested and has neither succeeded nor been canceled.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // Written by `ToolUseState::request_tool_use` but not read anywhere in this file,
    // hence the lint allowance.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Human-readable label; refreshed when the tool use starts running or needs
    /// confirmation.
    pub ui_text: Arc<str>,
    /// JSON input supplied for the tool.
    pub input: serde_json::Value,
    /// Where the tool use is in its lifecycle.
    pub status: PendingToolUseStatus,
}
397
/// Data captured by `ToolUseState::confirm_tool_use` for a tool use that is waiting
/// for confirmation. It is stored in [`PendingToolUseStatus::NeedsConfirmation`].
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// ID of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// Human-readable label for the tool use.
    pub ui_text: Arc<str>,
    /// Request messages captured with the confirmation (presumably the conversation
    /// context the tool runs against — confirm with the caller).
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to run once confirmed.
    pub tool: Arc<dyn Tool>,
}
406
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Requested but not started.
    Idle,
    /// Waiting for confirmation; holds what is needed to run the tool afterwards.
    NeedsConfirmation(Arc<Confirmation>),
    /// Running. The task is only held, never read (presumably to keep it alive, since
    /// dropping a gpui task cancels it — confirm). It is `Shared` so the enum can derive
    /// `Clone`.
    Running { _task: Shared<Task<()>> },
    /// The tool failed; holds the error message.
    // NOTE(review): this payload is read when `ToolUseState` derives
    // `ToolUseStatus::Error`, so the `allow(unused)` looks stale — consider removing it.
    Error(#[allow(unused)] Arc<str>),
}
414
415impl PendingToolUseStatus {
416    pub fn is_idle(&self) -> bool {
417        matches!(self, PendingToolUseStatus::Idle)
418    }
419
420    pub fn is_error(&self) -> bool {
421        matches!(self, PendingToolUseStatus::Error(_))
422    }
423
424    pub fn needs_confirmation(&self) -> bool {
425        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
426    }
427}