tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, SharedString, Task};
  9use language_model::{
 10    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 11    LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14
 15use crate::thread::MessageId;
 16use crate::thread_store::SerializedMessage;
 17
/// Display-ready view of a tool use requested by the assistant.
///
/// Instances are assembled by `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the requested tool.
    pub name: SharedString,
    /// Label produced by the tool's `ui_text` for this input, or an "Unknown tool" label
    /// when the tool is not in the working set.
    pub ui_text: SharedString,
    /// Current status, derived from the recorded result or the pending state.
    pub status: ToolUseStatus,
    /// Input the assistant supplied for the tool.
    pub input: serde_json::Value,
    /// Icon of the tool; `IconName::Cog` when the tool is not in the working set.
    pub icon: ui::IconName,
    /// Whether the tool reports that it needs confirmation; `false` for unknown tools.
    pub needs_confirmation: bool,
}
 28
/// Display status of a [`ToolUse`].
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The tool use is waiting for the user to confirm it.
    NeedsConfirmation,
    /// The tool use has no result yet and is neither running nor awaiting confirmation.
    Pending,
    /// The tool is currently running.
    Running,
    /// The tool completed successfully; holds its output.
    Finished(SharedString),
    /// The tool failed or was canceled; holds the error text.
    Error(SharedString),
}
 37
 38impl ToolUseStatus {
 39    pub fn text(&self) -> SharedString {
 40        match self {
 41            ToolUseStatus::NeedsConfirmation => "".into(),
 42            ToolUseStatus::Pending => "".into(),
 43            ToolUseStatus::Running => "".into(),
 44            ToolUseStatus::Finished(out) => out.clone(),
 45            ToolUseStatus::Error(out) => out.clone(),
 46        }
 47    }
 48}
 49
/// Tracks the tool uses requested in a thread, their results, and the ones still in flight.
pub struct ToolUseState {
    /// Working set used to look up tools by name (icons, labels, confirmation needs).
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in the order they were recorded.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful, failed, or canceled), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have not completed successfully; `insert_tool_output` keeps failed
    /// ones here with an error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
 57
/// Placeholder text inserted before tool uses in a request message that would otherwise
/// carry no text alongside them (see the API note in `ToolUseState::attach_tool_uses`).
pub const USING_TOOL_MARKER: &str = "<using_tool>";
 59
 60impl ToolUseState {
 61    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
 62        Self {
 63            tools,
 64            tool_uses_by_assistant_message: HashMap::default(),
 65            tool_uses_by_user_message: HashMap::default(),
 66            tool_results: HashMap::default(),
 67            pending_tool_uses_by_id: HashMap::default(),
 68        }
 69    }
 70
 71    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 72    ///
 73    /// Accepts a function to filter the tools that should be used to populate the state.
 74    pub fn from_serialized_messages(
 75        tools: Arc<ToolWorkingSet>,
 76        messages: &[SerializedMessage],
 77        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 78    ) -> Self {
 79        let mut this = Self::new(tools);
 80        let mut tool_names_by_id = HashMap::default();
 81
 82        for message in messages {
 83            match message.role {
 84                Role::Assistant => {
 85                    if !message.tool_uses.is_empty() {
 86                        let tool_uses = message
 87                            .tool_uses
 88                            .iter()
 89                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 90                            .map(|tool_use| LanguageModelToolUse {
 91                                id: tool_use.id.clone(),
 92                                name: tool_use.name.clone().into(),
 93                                input: tool_use.input.clone(),
 94                            })
 95                            .collect::<Vec<_>>();
 96
 97                        tool_names_by_id.extend(
 98                            tool_uses
 99                                .iter()
100                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
101                        );
102
103                        this.tool_uses_by_assistant_message
104                            .insert(message.id, tool_uses);
105                    }
106                }
107                Role::User => {
108                    if !message.tool_results.is_empty() {
109                        let tool_uses_by_user_message = this
110                            .tool_uses_by_user_message
111                            .entry(message.id)
112                            .or_default();
113
114                        for tool_result in &message.tool_results {
115                            let tool_use_id = tool_result.tool_use_id.clone();
116                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
117                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
118                                continue;
119                            };
120
121                            if !(filter_by_tool_name)(tool_use.as_ref()) {
122                                continue;
123                            }
124
125                            tool_uses_by_user_message.push(tool_use_id.clone());
126                            this.tool_results.insert(
127                                tool_use_id.clone(),
128                                LanguageModelToolResult {
129                                    tool_use_id,
130                                    tool_name: tool_use.clone(),
131                                    is_error: tool_result.is_error,
132                                    content: tool_result.content.clone(),
133                                },
134                            );
135                        }
136                    }
137                }
138                Role::System => {}
139            }
140        }
141
142        this
143    }
144
145    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
146        let mut pending_tools = Vec::new();
147        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
148            self.tool_results.insert(
149                tool_use_id.clone(),
150                LanguageModelToolResult {
151                    tool_use_id,
152                    tool_name: tool_use.name.clone(),
153                    content: "Tool canceled by user".into(),
154                    is_error: true,
155                },
156            );
157            pending_tools.push(tool_use.clone());
158        }
159        pending_tools
160    }
161
162    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
163        self.pending_tool_uses_by_id.values().collect()
164    }
165
166    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
167        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
168            return Vec::new();
169        };
170
171        let mut tool_uses = Vec::new();
172
173        for tool_use in tool_uses_for_message.iter() {
174            let tool_result = self.tool_results.get(&tool_use.id);
175
176            let status = (|| {
177                if let Some(tool_result) = tool_result {
178                    return if tool_result.is_error {
179                        ToolUseStatus::Error(tool_result.content.clone().into())
180                    } else {
181                        ToolUseStatus::Finished(tool_result.content.clone().into())
182                    };
183                }
184
185                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
186                    match pending_tool_use.status {
187                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
188                        PendingToolUseStatus::NeedsConfirmation { .. } => {
189                            ToolUseStatus::NeedsConfirmation
190                        }
191                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
192                        PendingToolUseStatus::Error(ref err) => {
193                            ToolUseStatus::Error(err.clone().into())
194                        }
195                    }
196                } else {
197                    ToolUseStatus::Pending
198                }
199            })();
200
201            let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
202            {
203                (tool.icon(), tool.needs_confirmation())
204            } else {
205                (IconName::Cog, false)
206            };
207
208            tool_uses.push(ToolUse {
209                id: tool_use.id.clone(),
210                name: tool_use.name.clone().into(),
211                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
212                input: tool_use.input.clone(),
213                status,
214                icon,
215                needs_confirmation,
216            })
217        }
218
219        tool_uses
220    }
221
222    pub fn tool_ui_label(
223        &self,
224        tool_name: &str,
225        input: &serde_json::Value,
226        cx: &App,
227    ) -> SharedString {
228        if let Some(tool) = self.tools.tool(tool_name, cx) {
229            tool.ui_text(input).into()
230        } else {
231            format!("Unknown tool {tool_name:?}").into()
232        }
233    }
234
235    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
236        let empty = Vec::new();
237
238        self.tool_uses_by_user_message
239            .get(&message_id)
240            .unwrap_or(&empty)
241            .iter()
242            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
243            .collect()
244    }
245
246    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
247        self.tool_uses_by_user_message
248            .get(&message_id)
249            .map_or(false, |results| !results.is_empty())
250    }
251
252    pub fn tool_result(
253        &self,
254        tool_use_id: &LanguageModelToolUseId,
255    ) -> Option<&LanguageModelToolResult> {
256        self.tool_results.get(tool_use_id)
257    }
258
259    pub fn request_tool_use(
260        &mut self,
261        assistant_message_id: MessageId,
262        tool_use: LanguageModelToolUse,
263        cx: &App,
264    ) {
265        self.tool_uses_by_assistant_message
266            .entry(assistant_message_id)
267            .or_default()
268            .push(tool_use.clone());
269
270        // The tool use is being requested by the Assistant, so we want to
271        // attach the tool results to the next user message.
272        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
273        self.tool_uses_by_user_message
274            .entry(next_user_message_id)
275            .or_default()
276            .push(tool_use.id.clone());
277
278        self.pending_tool_uses_by_id.insert(
279            tool_use.id.clone(),
280            PendingToolUse {
281                assistant_message_id,
282                id: tool_use.id,
283                name: tool_use.name.clone(),
284                ui_text: self
285                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
286                    .into(),
287                input: tool_use.input,
288                status: PendingToolUseStatus::Idle,
289            },
290        );
291    }
292
293    pub fn run_pending_tool(
294        &mut self,
295        tool_use_id: LanguageModelToolUseId,
296        ui_text: SharedString,
297        task: Task<()>,
298    ) {
299        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
300            tool_use.ui_text = ui_text.into();
301            tool_use.status = PendingToolUseStatus::Running {
302                _task: task.shared(),
303            };
304        }
305    }
306
307    pub fn confirm_tool_use(
308        &mut self,
309        tool_use_id: LanguageModelToolUseId,
310        ui_text: impl Into<Arc<str>>,
311        input: serde_json::Value,
312        messages: Arc<Vec<LanguageModelRequestMessage>>,
313        tool: Arc<dyn Tool>,
314    ) {
315        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
316            let ui_text = ui_text.into();
317            tool_use.ui_text = ui_text.clone();
318            let confirmation = Confirmation {
319                tool_use_id,
320                input,
321                messages,
322                tool,
323                ui_text,
324            };
325            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
326        }
327    }
328
329    pub fn insert_tool_output(
330        &mut self,
331        tool_use_id: LanguageModelToolUseId,
332        tool_name: Arc<str>,
333        output: Result<String>,
334    ) -> Option<PendingToolUse> {
335        match output {
336            Ok(tool_result) => {
337                self.tool_results.insert(
338                    tool_use_id.clone(),
339                    LanguageModelToolResult {
340                        tool_use_id: tool_use_id.clone(),
341                        tool_name,
342                        content: tool_result.into(),
343                        is_error: false,
344                    },
345                );
346                self.pending_tool_uses_by_id.remove(&tool_use_id)
347            }
348            Err(err) => {
349                self.tool_results.insert(
350                    tool_use_id.clone(),
351                    LanguageModelToolResult {
352                        tool_use_id: tool_use_id.clone(),
353                        tool_name,
354                        content: err.to_string().into(),
355                        is_error: true,
356                    },
357                );
358
359                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
360                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
361                }
362
363                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
364            }
365        }
366    }
367
368    pub fn attach_tool_uses(
369        &self,
370        message_id: MessageId,
371        request_message: &mut LanguageModelRequestMessage,
372    ) {
373        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
374            let mut found_tool_use = false;
375
376            for tool_use in tool_uses {
377                if self.tool_results.contains_key(&tool_use.id) {
378                    if !found_tool_use {
379                        // The API fails if a message contains a tool use without any (non-whitespace) text around it
380                        match request_message.content.last_mut() {
381                            Some(MessageContent::Text(txt)) => {
382                                if txt.is_empty() {
383                                    txt.push_str(USING_TOOL_MARKER);
384                                }
385                            }
386                            None | Some(_) => {
387                                request_message
388                                    .content
389                                    .push(MessageContent::Text(USING_TOOL_MARKER.into()));
390                            }
391                        };
392                    }
393
394                    found_tool_use = true;
395
396                    // Do not send tool uses until they are completed
397                    request_message
398                        .content
399                        .push(MessageContent::ToolUse(tool_use.clone()));
400                } else {
401                    log::debug!(
402                        "skipped tool use {:?} because it is still pending",
403                        tool_use
404                    );
405                }
406            }
407        }
408    }
409
410    pub fn attach_tool_results(
411        &self,
412        message_id: MessageId,
413        request_message: &mut LanguageModelRequestMessage,
414    ) {
415        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
416            for tool_use_id in tool_uses {
417                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
418                    request_message.content.push(MessageContent::ToolResult(
419                        LanguageModelToolResult {
420                            tool_use_id: tool_use_id.clone(),
421                            tool_name: tool_result.tool_name.clone(),
422                            is_error: tool_result.is_error,
423                            content: if tool_result.content.is_empty() {
424                                // Surprisingly, the API fails if we return an empty string here.
425                                // It thinks we are sending a tool use without a tool result.
426                                "<Tool returned an empty string>".into()
427                            } else {
428                                tool_result.content.clone()
429                            },
430                        },
431                    ));
432                }
433            }
434        }
435    }
436}
437
/// A tool use that has been requested but has not completed successfully.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read within this module as far as visible; the allow silences the
    // dead-code lint — confirm whether any caller still needs it.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label shown for this tool use; updated when it starts running or needs confirmation.
    pub ui_text: Arc<str>,
    /// Input the assistant supplied for the tool.
    pub input: serde_json::Value,
    /// Where the tool use is in its lifecycle.
    pub status: PendingToolUseStatus,
}
449
/// Data captured by `ToolUseState::confirm_tool_use` for a tool use awaiting user confirmation.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// ID of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Tool input recorded when confirmation was requested.
    pub input: serde_json::Value,
    /// Label shown for the tool use while it awaits confirmation.
    pub ui_text: Arc<str>,
    /// Request messages captured alongside the confirmation (presumably handed to the tool
    /// when it runs — confirm against callers).
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
458
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Requested but neither running nor awaiting confirmation.
    Idle,
    /// Waiting for the user to confirm; carries the data recorded for that confirmation.
    NeedsConfirmation(Arc<Confirmation>),
    /// Running; holds the tool's task (presumably so it is not dropped while the tool
    /// runs — confirm against gpui `Task` drop semantics).
    Running { _task: Shared<Task<()>> },
    /// The tool failed; holds the error message.
    // NOTE(review): `ToolUseState::tool_uses_for_message` reads this payload, so the
    // `allow(unused)` below may be stale.
    Error(#[allow(unused)] Arc<str>),
}
466
467impl PendingToolUseStatus {
468    pub fn is_idle(&self) -> bool {
469        matches!(self, PendingToolUseStatus::Idle)
470    }
471
472    pub fn is_error(&self) -> bool {
473        matches!(self, PendingToolUseStatus::Error(_))
474    }
475
476    pub fn needs_confirmation(&self) -> bool {
477        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
478    }
479}