1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14
15use crate::thread::MessageId;
16use crate::thread_store::SerializedMessage;
17
18#[derive(Debug)]
19pub struct ToolUse {
20 pub id: LanguageModelToolUseId,
21 pub name: SharedString,
22 pub ui_text: SharedString,
23 pub status: ToolUseStatus,
24 pub input: serde_json::Value,
25 pub icon: ui::IconName,
26 pub needs_confirmation: bool,
27}
28
29#[derive(Debug, Clone)]
30pub enum ToolUseStatus {
31 NeedsConfirmation,
32 Pending,
33 Running,
34 Finished(SharedString),
35 Error(SharedString),
36}
37
38pub struct ToolUseState {
39 tools: Arc<ToolWorkingSet>,
40 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
41 tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
42 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
43 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
44}
45
46impl ToolUseState {
47 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
48 Self {
49 tools,
50 tool_uses_by_assistant_message: HashMap::default(),
51 tool_uses_by_user_message: HashMap::default(),
52 tool_results: HashMap::default(),
53 pending_tool_uses_by_id: HashMap::default(),
54 }
55 }
56
57 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
58 ///
59 /// Accepts a function to filter the tools that should be used to populate the state.
60 pub fn from_serialized_messages(
61 tools: Arc<ToolWorkingSet>,
62 messages: &[SerializedMessage],
63 mut filter_by_tool_name: impl FnMut(&str) -> bool,
64 ) -> Self {
65 let mut this = Self::new(tools);
66 let mut tool_names_by_id = HashMap::default();
67
68 for message in messages {
69 match message.role {
70 Role::Assistant => {
71 if !message.tool_uses.is_empty() {
72 let tool_uses = message
73 .tool_uses
74 .iter()
75 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
76 .map(|tool_use| LanguageModelToolUse {
77 id: tool_use.id.clone(),
78 name: tool_use.name.clone().into(),
79 input: tool_use.input.clone(),
80 })
81 .collect::<Vec<_>>();
82
83 tool_names_by_id.extend(
84 tool_uses
85 .iter()
86 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
87 );
88
89 this.tool_uses_by_assistant_message
90 .insert(message.id, tool_uses);
91 }
92 }
93 Role::User => {
94 if !message.tool_results.is_empty() {
95 let tool_uses_by_user_message = this
96 .tool_uses_by_user_message
97 .entry(message.id)
98 .or_default();
99
100 for tool_result in &message.tool_results {
101 let tool_use_id = tool_result.tool_use_id.clone();
102 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
103 log::warn!("no tool name found for tool use: {tool_use_id:?}");
104 continue;
105 };
106
107 if !(filter_by_tool_name)(tool_use.as_ref()) {
108 continue;
109 }
110
111 tool_uses_by_user_message.push(tool_use_id.clone());
112 this.tool_results.insert(
113 tool_use_id.clone(),
114 LanguageModelToolResult {
115 tool_use_id,
116 tool_name: tool_use.clone(),
117 is_error: tool_result.is_error,
118 content: tool_result.content.clone(),
119 },
120 );
121 }
122 }
123 }
124 Role::System => {}
125 }
126 }
127
128 this
129 }
130
131 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
132 let mut pending_tools = Vec::new();
133 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
134 self.tool_results.insert(
135 tool_use_id.clone(),
136 LanguageModelToolResult {
137 tool_use_id,
138 tool_name: tool_use.name.clone(),
139 content: "Tool canceled by user".into(),
140 is_error: true,
141 },
142 );
143 pending_tools.push(tool_use.clone());
144 }
145 pending_tools
146 }
147
148 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
149 self.pending_tool_uses_by_id.values().collect()
150 }
151
152 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
153 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
154 return Vec::new();
155 };
156
157 let mut tool_uses = Vec::new();
158
159 for tool_use in tool_uses_for_message.iter() {
160 let tool_result = self.tool_results.get(&tool_use.id);
161
162 let status = (|| {
163 if let Some(tool_result) = tool_result {
164 return if tool_result.is_error {
165 ToolUseStatus::Error(tool_result.content.clone().into())
166 } else {
167 ToolUseStatus::Finished(tool_result.content.clone().into())
168 };
169 }
170
171 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
172 match pending_tool_use.status {
173 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
174 PendingToolUseStatus::NeedsConfirmation { .. } => {
175 ToolUseStatus::NeedsConfirmation
176 }
177 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
178 PendingToolUseStatus::Error(ref err) => {
179 ToolUseStatus::Error(err.clone().into())
180 }
181 }
182 } else {
183 ToolUseStatus::Pending
184 }
185 })();
186
187 let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
188 {
189 (tool.icon(), tool.needs_confirmation())
190 } else {
191 (IconName::Cog, false)
192 };
193
194 tool_uses.push(ToolUse {
195 id: tool_use.id.clone(),
196 name: tool_use.name.clone().into(),
197 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
198 input: tool_use.input.clone(),
199 status,
200 icon,
201 needs_confirmation,
202 })
203 }
204
205 tool_uses
206 }
207
208 pub fn tool_ui_label(
209 &self,
210 tool_name: &str,
211 input: &serde_json::Value,
212 cx: &App,
213 ) -> SharedString {
214 if let Some(tool) = self.tools.tool(tool_name, cx) {
215 tool.ui_text(input).into()
216 } else {
217 format!("Unknown tool {tool_name:?}").into()
218 }
219 }
220
221 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
222 let empty = Vec::new();
223
224 self.tool_uses_by_user_message
225 .get(&message_id)
226 .unwrap_or(&empty)
227 .iter()
228 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
229 .collect()
230 }
231
232 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
233 self.tool_uses_by_user_message
234 .get(&message_id)
235 .map_or(false, |results| !results.is_empty())
236 }
237
238 pub fn tool_result(
239 &self,
240 tool_use_id: &LanguageModelToolUseId,
241 ) -> Option<&LanguageModelToolResult> {
242 self.tool_results.get(tool_use_id)
243 }
244
245 pub fn request_tool_use(
246 &mut self,
247 assistant_message_id: MessageId,
248 tool_use: LanguageModelToolUse,
249 cx: &App,
250 ) {
251 self.tool_uses_by_assistant_message
252 .entry(assistant_message_id)
253 .or_default()
254 .push(tool_use.clone());
255
256 // The tool use is being requested by the Assistant, so we want to
257 // attach the tool results to the next user message.
258 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
259 self.tool_uses_by_user_message
260 .entry(next_user_message_id)
261 .or_default()
262 .push(tool_use.id.clone());
263
264 self.pending_tool_uses_by_id.insert(
265 tool_use.id.clone(),
266 PendingToolUse {
267 assistant_message_id,
268 id: tool_use.id,
269 name: tool_use.name.clone(),
270 ui_text: self
271 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
272 .into(),
273 input: tool_use.input,
274 status: PendingToolUseStatus::Idle,
275 },
276 );
277 }
278
279 pub fn run_pending_tool(
280 &mut self,
281 tool_use_id: LanguageModelToolUseId,
282 ui_text: SharedString,
283 task: Task<()>,
284 ) {
285 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
286 tool_use.ui_text = ui_text.into();
287 tool_use.status = PendingToolUseStatus::Running {
288 _task: task.shared(),
289 };
290 }
291 }
292
293 pub fn confirm_tool_use(
294 &mut self,
295 tool_use_id: LanguageModelToolUseId,
296 ui_text: impl Into<Arc<str>>,
297 input: serde_json::Value,
298 messages: Arc<Vec<LanguageModelRequestMessage>>,
299 tool: Arc<dyn Tool>,
300 ) {
301 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
302 let ui_text = ui_text.into();
303 tool_use.ui_text = ui_text.clone();
304 let confirmation = Confirmation {
305 tool_use_id,
306 input,
307 messages,
308 tool,
309 ui_text,
310 };
311 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
312 }
313 }
314
315 pub fn insert_tool_output(
316 &mut self,
317 tool_use_id: LanguageModelToolUseId,
318 tool_name: Arc<str>,
319 output: Result<String>,
320 ) -> Option<PendingToolUse> {
321 match output {
322 Ok(tool_result) => {
323 self.tool_results.insert(
324 tool_use_id.clone(),
325 LanguageModelToolResult {
326 tool_use_id: tool_use_id.clone(),
327 tool_name,
328 content: tool_result.into(),
329 is_error: false,
330 },
331 );
332 self.pending_tool_uses_by_id.remove(&tool_use_id)
333 }
334 Err(err) => {
335 self.tool_results.insert(
336 tool_use_id.clone(),
337 LanguageModelToolResult {
338 tool_use_id: tool_use_id.clone(),
339 tool_name,
340 content: err.to_string().into(),
341 is_error: true,
342 },
343 );
344
345 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
346 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
347 }
348
349 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
350 }
351 }
352 }
353
354 pub fn attach_tool_uses(
355 &self,
356 message_id: MessageId,
357 request_message: &mut LanguageModelRequestMessage,
358 ) {
359 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
360 for tool_use in tool_uses {
361 if self.tool_results.contains_key(&tool_use.id) {
362 // Do not send tool uses until they are completed
363 request_message
364 .content
365 .push(MessageContent::ToolUse(tool_use.clone()));
366 } else {
367 log::debug!(
368 "skipped tool use {:?} because it is still pending",
369 tool_use
370 );
371 }
372 }
373 }
374 }
375
376 pub fn attach_tool_results(
377 &self,
378 message_id: MessageId,
379 request_message: &mut LanguageModelRequestMessage,
380 ) {
381 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
382 for tool_use_id in tool_uses {
383 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
384 request_message.content.push(MessageContent::ToolResult(
385 LanguageModelToolResult {
386 tool_use_id: tool_use_id.clone(),
387 tool_name: tool_result.tool_name.clone(),
388 is_error: tool_result.is_error,
389 content: if tool_result.content.is_empty() {
390 // Surprisingly, the API fails if we return an empty string here.
391 // It thinks we are sending a tool use without a tool result.
392 "<Tool returned an empty string>".into()
393 } else {
394 tool_result.content.clone()
395 },
396 },
397 ));
398 }
399 }
400 }
401 }
402}
403
404#[derive(Debug, Clone)]
405pub struct PendingToolUse {
406 pub id: LanguageModelToolUseId,
407 /// The ID of the Assistant message in which the tool use was requested.
408 #[allow(unused)]
409 pub assistant_message_id: MessageId,
410 pub name: Arc<str>,
411 pub ui_text: Arc<str>,
412 pub input: serde_json::Value,
413 pub status: PendingToolUseStatus,
414}
415
416#[derive(Debug, Clone)]
417pub struct Confirmation {
418 pub tool_use_id: LanguageModelToolUseId,
419 pub input: serde_json::Value,
420 pub ui_text: Arc<str>,
421 pub messages: Arc<Vec<LanguageModelRequestMessage>>,
422 pub tool: Arc<dyn Tool>,
423}
424
425#[derive(Debug, Clone)]
426pub enum PendingToolUseStatus {
427 Idle,
428 NeedsConfirmation(Arc<Confirmation>),
429 Running { _task: Shared<Task<()>> },
430 Error(#[allow(unused)] Arc<str>),
431}
432
433impl PendingToolUseStatus {
434 pub fn is_idle(&self) -> bool {
435 matches!(self, PendingToolUseStatus::Idle)
436 }
437
438 pub fn is_error(&self) -> bool {
439 matches!(self, PendingToolUseStatus::Error(_))
440 }
441
442 pub fn needs_confirmation(&self) -> bool {
443 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
444 }
445}