1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashMap;
5use futures::future::Shared;
6use futures::FutureExt as _;
7use gpui::{SharedString, Task};
8use language_model::{
9 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
10 LanguageModelToolUseId, MessageContent, Role,
11};
12
13use crate::thread::MessageId;
14use crate::thread_store::SerializedMessage;
15
16#[derive(Debug)]
17pub struct ToolUse {
18 pub id: LanguageModelToolUseId,
19 pub name: SharedString,
20 pub status: ToolUseStatus,
21 pub input: serde_json::Value,
22}
23
24#[derive(Debug, Clone)]
25pub enum ToolUseStatus {
26 Pending,
27 Running,
28 Finished(SharedString),
29 Error(SharedString),
30}
31
32pub struct ToolUseState {
33 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
34 tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
35 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
36 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
37}
38
39impl ToolUseState {
40 pub fn new() -> Self {
41 Self {
42 tool_uses_by_assistant_message: HashMap::default(),
43 tool_uses_by_user_message: HashMap::default(),
44 tool_results: HashMap::default(),
45 pending_tool_uses_by_id: HashMap::default(),
46 }
47 }
48
49 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
50 ///
51 /// Accepts a function to filter the tools that should be used to populate the state.
52 pub fn from_serialized_messages(
53 messages: &[SerializedMessage],
54 mut filter_by_tool_name: impl FnMut(&str) -> bool,
55 ) -> Self {
56 let mut this = Self::new();
57 let mut tool_names_by_id = HashMap::default();
58
59 for message in messages {
60 match message.role {
61 Role::Assistant => {
62 if !message.tool_uses.is_empty() {
63 let tool_uses = message
64 .tool_uses
65 .iter()
66 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
67 .map(|tool_use| LanguageModelToolUse {
68 id: tool_use.id.clone(),
69 name: tool_use.name.clone().into(),
70 input: tool_use.input.clone(),
71 })
72 .collect::<Vec<_>>();
73
74 tool_names_by_id.extend(
75 tool_uses
76 .iter()
77 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
78 );
79
80 this.tool_uses_by_assistant_message
81 .insert(message.id, tool_uses);
82 }
83 }
84 Role::User => {
85 if !message.tool_results.is_empty() {
86 let tool_uses_by_user_message = this
87 .tool_uses_by_user_message
88 .entry(message.id)
89 .or_default();
90
91 for tool_result in &message.tool_results {
92 let tool_use_id = tool_result.tool_use_id.clone();
93 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
94 log::warn!("no tool name found for tool use: {tool_use_id:?}");
95 continue;
96 };
97
98 if !(filter_by_tool_name)(tool_use.as_ref()) {
99 continue;
100 }
101
102 tool_uses_by_user_message.push(tool_use_id.clone());
103 this.tool_results.insert(
104 tool_use_id.clone(),
105 LanguageModelToolResult {
106 tool_use_id,
107 is_error: tool_result.is_error,
108 content: tool_result.content.clone(),
109 },
110 );
111 }
112 }
113 }
114 Role::System => {}
115 }
116 }
117
118 this
119 }
120
121 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
122 let mut pending_tools = Vec::new();
123 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
124 self.tool_results.insert(
125 tool_use_id.clone(),
126 LanguageModelToolResult {
127 tool_use_id,
128 content: "Tool canceled by user".into(),
129 is_error: true,
130 },
131 );
132 pending_tools.push(tool_use.clone());
133 }
134 pending_tools
135 }
136
137 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
138 self.pending_tool_uses_by_id.values().collect()
139 }
140
141 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
142 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
143 return Vec::new();
144 };
145
146 let mut tool_uses = Vec::new();
147
148 for tool_use in tool_uses_for_message.iter() {
149 let tool_result = self.tool_results.get(&tool_use.id);
150
151 let status = (|| {
152 if let Some(tool_result) = tool_result {
153 return if tool_result.is_error {
154 ToolUseStatus::Error(tool_result.content.clone().into())
155 } else {
156 ToolUseStatus::Finished(tool_result.content.clone().into())
157 };
158 }
159
160 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
161 return match pending_tool_use.status {
162 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
163 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
164 PendingToolUseStatus::Error(ref err) => {
165 ToolUseStatus::Error(err.clone().into())
166 }
167 };
168 }
169
170 ToolUseStatus::Pending
171 })();
172
173 tool_uses.push(ToolUse {
174 id: tool_use.id.clone(),
175 name: tool_use.name.clone().into(),
176 input: tool_use.input.clone(),
177 status,
178 })
179 }
180
181 tool_uses
182 }
183
184 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
185 let empty = Vec::new();
186
187 self.tool_uses_by_user_message
188 .get(&message_id)
189 .unwrap_or(&empty)
190 .iter()
191 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
192 .collect()
193 }
194
195 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
196 self.tool_uses_by_user_message
197 .get(&message_id)
198 .map_or(false, |results| !results.is_empty())
199 }
200
201 pub fn tool_result(
202 &self,
203 tool_use_id: &LanguageModelToolUseId,
204 ) -> Option<&LanguageModelToolResult> {
205 self.tool_results.get(tool_use_id)
206 }
207
208 pub fn request_tool_use(
209 &mut self,
210 assistant_message_id: MessageId,
211 tool_use: LanguageModelToolUse,
212 ) {
213 self.tool_uses_by_assistant_message
214 .entry(assistant_message_id)
215 .or_default()
216 .push(tool_use.clone());
217
218 // The tool use is being requested by the Assistant, so we want to
219 // attach the tool results to the next user message.
220 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
221 self.tool_uses_by_user_message
222 .entry(next_user_message_id)
223 .or_default()
224 .push(tool_use.id.clone());
225
226 self.pending_tool_uses_by_id.insert(
227 tool_use.id.clone(),
228 PendingToolUse {
229 assistant_message_id,
230 id: tool_use.id,
231 name: tool_use.name,
232 input: tool_use.input,
233 status: PendingToolUseStatus::Idle,
234 },
235 );
236 }
237
238 pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
239 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
240 tool_use.status = PendingToolUseStatus::Running {
241 _task: task.shared(),
242 };
243 }
244 }
245
246 pub fn insert_tool_output(
247 &mut self,
248 tool_use_id: LanguageModelToolUseId,
249 output: Result<String>,
250 ) -> Option<PendingToolUse> {
251 match output {
252 Ok(tool_result) => {
253 self.tool_results.insert(
254 tool_use_id.clone(),
255 LanguageModelToolResult {
256 tool_use_id: tool_use_id.clone(),
257 content: tool_result.into(),
258 is_error: false,
259 },
260 );
261 self.pending_tool_uses_by_id.remove(&tool_use_id)
262 }
263 Err(err) => {
264 self.tool_results.insert(
265 tool_use_id.clone(),
266 LanguageModelToolResult {
267 tool_use_id: tool_use_id.clone(),
268 content: err.to_string().into(),
269 is_error: true,
270 },
271 );
272
273 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
274 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
275 }
276
277 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
278 }
279 }
280 }
281
282 pub fn attach_tool_uses(
283 &self,
284 message_id: MessageId,
285 request_message: &mut LanguageModelRequestMessage,
286 ) {
287 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
288 for tool_use in tool_uses {
289 if self.tool_results.contains_key(&tool_use.id) {
290 // Do not send tool uses until they are completed
291 request_message
292 .content
293 .push(MessageContent::ToolUse(tool_use.clone()));
294 } else {
295 log::debug!(
296 "skipped tool use {:?} because it is still pending",
297 tool_use
298 );
299 }
300 }
301 }
302 }
303
304 pub fn attach_tool_results(
305 &self,
306 message_id: MessageId,
307 request_message: &mut LanguageModelRequestMessage,
308 ) {
309 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
310 for tool_use_id in tool_uses {
311 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
312 request_message.content.push(MessageContent::ToolResult(
313 LanguageModelToolResult {
314 tool_use_id: tool_use_id.clone(),
315 is_error: tool_result.is_error,
316 content: if tool_result.content.is_empty() {
317 // Surprisingly, the API fails if we return an empty string here.
318 // It thinks we are sending a tool use without a tool result.
319 "<Tool returned an empty string>".into()
320 } else {
321 tool_result.content.clone()
322 },
323 },
324 ));
325 }
326 }
327 }
328 }
329}
330
331#[derive(Debug, Clone)]
332pub struct PendingToolUse {
333 pub id: LanguageModelToolUseId,
334 /// The ID of the Assistant message in which the tool use was requested.
335 #[allow(unused)]
336 pub assistant_message_id: MessageId,
337 pub name: Arc<str>,
338 pub input: serde_json::Value,
339 pub status: PendingToolUseStatus,
340}
341
342#[derive(Debug, Clone)]
343pub enum PendingToolUseStatus {
344 Idle,
345 Running { _task: Shared<Task<()>> },
346 Error(#[allow(unused)] Arc<str>),
347}
348
349impl PendingToolUseStatus {
350 pub fn is_idle(&self) -> bool {
351 matches!(self, PendingToolUseStatus::Idle)
352 }
353
354 pub fn is_error(&self) -> bool {
355 matches!(self, PendingToolUseStatus::Error(_))
356 }
357}