1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use futures::future::Shared;
7use futures::FutureExt as _;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13
14use crate::thread::MessageId;
15use crate::thread_store::SerializedMessage;
16
/// A tool use as presented to the UI: the model's request for a tool, together
/// with a human-readable label and its current status.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: SharedString,
    /// Human-readable label describing the invocation.
    pub ui_text: SharedString,
    /// Current status of the tool use.
    pub status: ToolUseStatus,
    /// JSON input passed to the tool.
    pub input: serde_json::Value,
}
25
/// Display status of a [`ToolUse`].
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The tool use has not started running.
    Pending,
    /// The tool is currently running.
    Running,
    /// The tool completed successfully; holds its output.
    Finished(SharedString),
    /// The tool failed (or was canceled); holds the error message.
    Error(SharedString),
}
33
/// Tracks the tool uses requested by the assistant within a thread, their
/// results, and which of them are still pending.
pub struct ToolUseState {
    /// Working set used to look up tools by name (e.g. to build UI labels).
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested in each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (success, error, or cancellation), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have not yet completed successfully.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
41
42impl ToolUseState {
43 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
44 Self {
45 tools,
46 tool_uses_by_assistant_message: HashMap::default(),
47 tool_uses_by_user_message: HashMap::default(),
48 tool_results: HashMap::default(),
49 pending_tool_uses_by_id: HashMap::default(),
50 }
51 }
52
53 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
54 ///
55 /// Accepts a function to filter the tools that should be used to populate the state.
56 pub fn from_serialized_messages(
57 tools: Arc<ToolWorkingSet>,
58 messages: &[SerializedMessage],
59 mut filter_by_tool_name: impl FnMut(&str) -> bool,
60 ) -> Self {
61 let mut this = Self::new(tools);
62 let mut tool_names_by_id = HashMap::default();
63
64 for message in messages {
65 match message.role {
66 Role::Assistant => {
67 if !message.tool_uses.is_empty() {
68 let tool_uses = message
69 .tool_uses
70 .iter()
71 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
72 .map(|tool_use| LanguageModelToolUse {
73 id: tool_use.id.clone(),
74 name: tool_use.name.clone().into(),
75 input: tool_use.input.clone(),
76 })
77 .collect::<Vec<_>>();
78
79 tool_names_by_id.extend(
80 tool_uses
81 .iter()
82 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
83 );
84
85 this.tool_uses_by_assistant_message
86 .insert(message.id, tool_uses);
87 }
88 }
89 Role::User => {
90 if !message.tool_results.is_empty() {
91 let tool_uses_by_user_message = this
92 .tool_uses_by_user_message
93 .entry(message.id)
94 .or_default();
95
96 for tool_result in &message.tool_results {
97 let tool_use_id = tool_result.tool_use_id.clone();
98 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
99 log::warn!("no tool name found for tool use: {tool_use_id:?}");
100 continue;
101 };
102
103 if !(filter_by_tool_name)(tool_use.as_ref()) {
104 continue;
105 }
106
107 tool_uses_by_user_message.push(tool_use_id.clone());
108 this.tool_results.insert(
109 tool_use_id.clone(),
110 LanguageModelToolResult {
111 tool_use_id,
112 is_error: tool_result.is_error,
113 content: tool_result.content.clone(),
114 },
115 );
116 }
117 }
118 }
119 Role::System => {}
120 }
121 }
122
123 this
124 }
125
126 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
127 let mut pending_tools = Vec::new();
128 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
129 self.tool_results.insert(
130 tool_use_id.clone(),
131 LanguageModelToolResult {
132 tool_use_id,
133 content: "Tool canceled by user".into(),
134 is_error: true,
135 },
136 );
137 pending_tools.push(tool_use.clone());
138 }
139 pending_tools
140 }
141
142 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
143 self.pending_tool_uses_by_id.values().collect()
144 }
145
146 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
147 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
148 return Vec::new();
149 };
150
151 let mut tool_uses = Vec::new();
152
153 for tool_use in tool_uses_for_message.iter() {
154 let tool_result = self.tool_results.get(&tool_use.id);
155
156 let status = (|| {
157 if let Some(tool_result) = tool_result {
158 return if tool_result.is_error {
159 ToolUseStatus::Error(tool_result.content.clone().into())
160 } else {
161 ToolUseStatus::Finished(tool_result.content.clone().into())
162 };
163 }
164
165 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
166 return match pending_tool_use.status {
167 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
168 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
169 PendingToolUseStatus::Error(ref err) => {
170 ToolUseStatus::Error(err.clone().into())
171 }
172 };
173 }
174
175 ToolUseStatus::Pending
176 })();
177
178 tool_uses.push(ToolUse {
179 id: tool_use.id.clone(),
180 name: tool_use.name.clone().into(),
181 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
182 input: tool_use.input.clone(),
183 status,
184 })
185 }
186
187 tool_uses
188 }
189
190 pub fn tool_ui_label(
191 &self,
192 tool_name: &str,
193 input: &serde_json::Value,
194 cx: &App,
195 ) -> SharedString {
196 if let Some(tool) = self.tools.tool(tool_name, cx) {
197 tool.ui_text(input).into()
198 } else {
199 "Unknown tool".into()
200 }
201 }
202
203 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
204 let empty = Vec::new();
205
206 self.tool_uses_by_user_message
207 .get(&message_id)
208 .unwrap_or(&empty)
209 .iter()
210 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
211 .collect()
212 }
213
214 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
215 self.tool_uses_by_user_message
216 .get(&message_id)
217 .map_or(false, |results| !results.is_empty())
218 }
219
220 pub fn tool_result(
221 &self,
222 tool_use_id: &LanguageModelToolUseId,
223 ) -> Option<&LanguageModelToolResult> {
224 self.tool_results.get(tool_use_id)
225 }
226
227 pub fn request_tool_use(
228 &mut self,
229 assistant_message_id: MessageId,
230 tool_use: LanguageModelToolUse,
231 cx: &App,
232 ) {
233 self.tool_uses_by_assistant_message
234 .entry(assistant_message_id)
235 .or_default()
236 .push(tool_use.clone());
237
238 // The tool use is being requested by the Assistant, so we want to
239 // attach the tool results to the next user message.
240 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
241 self.tool_uses_by_user_message
242 .entry(next_user_message_id)
243 .or_default()
244 .push(tool_use.id.clone());
245
246 self.pending_tool_uses_by_id.insert(
247 tool_use.id.clone(),
248 PendingToolUse {
249 assistant_message_id,
250 id: tool_use.id,
251 name: tool_use.name.clone(),
252 ui_text: self
253 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
254 .into(),
255 input: tool_use.input,
256 status: PendingToolUseStatus::Idle,
257 },
258 );
259 }
260
261 pub fn run_pending_tool(
262 &mut self,
263 tool_use_id: LanguageModelToolUseId,
264 ui_text: SharedString,
265 task: Task<()>,
266 ) {
267 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
268 tool_use.ui_text = ui_text.into();
269 tool_use.status = PendingToolUseStatus::Running {
270 _task: task.shared(),
271 };
272 }
273 }
274
275 pub fn insert_tool_output(
276 &mut self,
277 tool_use_id: LanguageModelToolUseId,
278 output: Result<String>,
279 ) -> Option<PendingToolUse> {
280 match output {
281 Ok(tool_result) => {
282 self.tool_results.insert(
283 tool_use_id.clone(),
284 LanguageModelToolResult {
285 tool_use_id: tool_use_id.clone(),
286 content: tool_result.into(),
287 is_error: false,
288 },
289 );
290 self.pending_tool_uses_by_id.remove(&tool_use_id)
291 }
292 Err(err) => {
293 self.tool_results.insert(
294 tool_use_id.clone(),
295 LanguageModelToolResult {
296 tool_use_id: tool_use_id.clone(),
297 content: err.to_string().into(),
298 is_error: true,
299 },
300 );
301
302 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
303 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
304 }
305
306 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
307 }
308 }
309 }
310
311 pub fn attach_tool_uses(
312 &self,
313 message_id: MessageId,
314 request_message: &mut LanguageModelRequestMessage,
315 ) {
316 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
317 for tool_use in tool_uses {
318 if self.tool_results.contains_key(&tool_use.id) {
319 // Do not send tool uses until they are completed
320 request_message
321 .content
322 .push(MessageContent::ToolUse(tool_use.clone()));
323 } else {
324 log::debug!(
325 "skipped tool use {:?} because it is still pending",
326 tool_use
327 );
328 }
329 }
330 }
331 }
332
333 pub fn attach_tool_results(
334 &self,
335 message_id: MessageId,
336 request_message: &mut LanguageModelRequestMessage,
337 ) {
338 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
339 for tool_use_id in tool_uses {
340 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
341 request_message.content.push(MessageContent::ToolResult(
342 LanguageModelToolResult {
343 tool_use_id: tool_use_id.clone(),
344 is_error: tool_result.is_error,
345 content: if tool_result.content.is_empty() {
346 // Surprisingly, the API fails if we return an empty string here.
347 // It thinks we are sending a tool use without a tool result.
348 "<Tool returned an empty string>".into()
349 } else {
350 tool_result.content.clone()
351 },
352 },
353 ));
354 }
355 }
356 }
357 }
358}
359
/// A tool use that has been requested but has not yet completed successfully.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read in this file; confirm whether this allow is still needed.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool to run.
    pub name: Arc<str>,
    /// Human-readable label for the invocation.
    pub ui_text: Arc<str>,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// Current lifecycle state of the tool use.
    pub status: PendingToolUseStatus,
}
371
372#[derive(Debug, Clone)]
373pub enum PendingToolUseStatus {
374 Idle,
375 Running { _task: Shared<Task<()>> },
376 Error(#[allow(unused)] Arc<str>),
377}
378
379impl PendingToolUseStatus {
380 pub fn is_idle(&self) -> bool {
381 matches!(self, PendingToolUseStatus::Idle)
382 }
383
384 pub fn is_error(&self) -> bool {
385 matches!(self, PendingToolUseStatus::Error(_))
386 }
387}