tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashMap;
  5use futures::future::Shared;
  6use futures::FutureExt as _;
  7use gpui::{SharedString, Task};
  8use language_model::{
  9    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 10    LanguageModelToolUseId, MessageContent,
 11};
 12
 13use crate::thread::MessageId;
 14
 15#[derive(Debug)]
 16pub struct ToolUse {
 17    pub id: LanguageModelToolUseId,
 18    pub name: SharedString,
 19    pub status: ToolUseStatus,
 20    pub input: serde_json::Value,
 21}
 22
 23#[derive(Debug, Clone)]
 24pub enum ToolUseStatus {
 25    Pending,
 26    Running,
 27    Finished(SharedString),
 28    Error(SharedString),
 29}
 30
 31#[derive(Default)]
 32pub struct ToolUseState {
 33    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 34    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 35    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 36    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 37}
 38
 39impl ToolUseState {
 40    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
 41        self.pending_tool_uses_by_id.values().collect()
 42    }
 43
 44    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
 45        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
 46            return Vec::new();
 47        };
 48
 49        let mut tool_uses = Vec::new();
 50
 51        for tool_use in tool_uses_for_message.iter() {
 52            let tool_result = self.tool_results.get(&tool_use.id);
 53
 54            let status = (|| {
 55                if let Some(tool_result) = tool_result {
 56                    return if tool_result.is_error {
 57                        ToolUseStatus::Error(tool_result.content.clone().into())
 58                    } else {
 59                        ToolUseStatus::Finished(tool_result.content.clone().into())
 60                    };
 61                }
 62
 63                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
 64                    return match pending_tool_use.status {
 65                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
 66                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
 67                        PendingToolUseStatus::Error(ref err) => {
 68                            ToolUseStatus::Error(err.clone().into())
 69                        }
 70                    };
 71                }
 72
 73                ToolUseStatus::Pending
 74            })();
 75
 76            tool_uses.push(ToolUse {
 77                id: tool_use.id.clone(),
 78                name: tool_use.name.clone().into(),
 79                input: tool_use.input.clone(),
 80                status,
 81            })
 82        }
 83
 84        tool_uses
 85    }
 86
 87    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 88        self.tool_uses_by_user_message
 89            .get(&message_id)
 90            .map_or(false, |results| !results.is_empty())
 91    }
 92
 93    pub fn request_tool_use(
 94        &mut self,
 95        assistant_message_id: MessageId,
 96        tool_use: LanguageModelToolUse,
 97    ) {
 98        self.tool_uses_by_assistant_message
 99            .entry(assistant_message_id)
100            .or_default()
101            .push(tool_use.clone());
102
103        // The tool use is being requested by the Assistant, so we want to
104        // attach the tool results to the next user message.
105        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
106        self.tool_uses_by_user_message
107            .entry(next_user_message_id)
108            .or_default()
109            .push(tool_use.id.clone());
110
111        self.pending_tool_uses_by_id.insert(
112            tool_use.id.clone(),
113            PendingToolUse {
114                assistant_message_id,
115                id: tool_use.id,
116                name: tool_use.name,
117                input: tool_use.input,
118                status: PendingToolUseStatus::Idle,
119            },
120        );
121    }
122
123    pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
124        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
125            tool_use.status = PendingToolUseStatus::Running {
126                _task: task.shared(),
127            };
128        }
129    }
130
131    pub fn insert_tool_output(
132        &mut self,
133        tool_use_id: LanguageModelToolUseId,
134        output: Result<String>,
135    ) {
136        match output {
137            Ok(output) => {
138                self.tool_results.insert(
139                    tool_use_id.clone(),
140                    LanguageModelToolResult {
141                        tool_use_id: tool_use_id.clone(),
142                        content: output.into(),
143                        is_error: false,
144                    },
145                );
146                self.pending_tool_uses_by_id.remove(&tool_use_id);
147            }
148            Err(err) => {
149                self.tool_results.insert(
150                    tool_use_id.clone(),
151                    LanguageModelToolResult {
152                        tool_use_id: tool_use_id.clone(),
153                        content: err.to_string().into(),
154                        is_error: true,
155                    },
156                );
157
158                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
159                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
160                }
161            }
162        }
163    }
164
165    pub fn attach_tool_uses(
166        &self,
167        message_id: MessageId,
168        request_message: &mut LanguageModelRequestMessage,
169    ) {
170        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
171            for tool_use in tool_uses {
172                request_message
173                    .content
174                    .push(MessageContent::ToolUse(tool_use.clone()));
175            }
176        }
177    }
178
179    pub fn attach_tool_results(
180        &self,
181        message_id: MessageId,
182        request_message: &mut LanguageModelRequestMessage,
183    ) {
184        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
185            for tool_use_id in tool_uses {
186                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
187                    request_message
188                        .content
189                        .push(MessageContent::ToolResult(tool_result.clone()));
190                }
191            }
192        }
193    }
194}
195
196#[derive(Debug, Clone)]
197pub struct PendingToolUse {
198    pub id: LanguageModelToolUseId,
199    /// The ID of the Assistant message in which the tool use was requested.
200    pub assistant_message_id: MessageId,
201    pub name: Arc<str>,
202    pub input: serde_json::Value,
203    pub status: PendingToolUseStatus,
204}
205
206#[derive(Debug, Clone)]
207pub enum PendingToolUseStatus {
208    Idle,
209    Running { _task: Shared<Task<()>> },
210    Error(#[allow(unused)] Arc<str>),
211}
212
213impl PendingToolUseStatus {
214    pub fn is_idle(&self) -> bool {
215        matches!(self, PendingToolUseStatus::Idle)
216    }
217
218    pub fn is_error(&self) -> bool {
219        matches!(self, PendingToolUseStatus::Error(_))
220    }
221}