1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14
15use crate::thread::MessageId;
16use crate::thread_store::SerializedMessage;
17
18#[derive(Debug)]
19pub struct ToolUse {
20 pub id: LanguageModelToolUseId,
21 pub name: SharedString,
22 pub ui_text: SharedString,
23 pub status: ToolUseStatus,
24 pub input: serde_json::Value,
25 pub icon: ui::IconName,
26 pub needs_confirmation: bool,
27}
28
29#[derive(Debug, Clone)]
30pub enum ToolUseStatus {
31 NeedsConfirmation,
32 Pending,
33 Running,
34 Finished(SharedString),
35 Error(SharedString),
36}
37
38pub struct ToolUseState {
39 tools: Arc<ToolWorkingSet>,
40 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
41 tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
42 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
43 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
44}
45
/// Placeholder text placed next to tool uses in an outgoing request message when that
/// message has no other text, because the API rejects a tool use that has no
/// (non-whitespace) text around it.
pub const USING_TOOL_MARKER: &str = "<using_tool>";
47
48impl ToolUseState {
49 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
50 Self {
51 tools,
52 tool_uses_by_assistant_message: HashMap::default(),
53 tool_uses_by_user_message: HashMap::default(),
54 tool_results: HashMap::default(),
55 pending_tool_uses_by_id: HashMap::default(),
56 }
57 }
58
59 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
60 ///
61 /// Accepts a function to filter the tools that should be used to populate the state.
62 pub fn from_serialized_messages(
63 tools: Arc<ToolWorkingSet>,
64 messages: &[SerializedMessage],
65 mut filter_by_tool_name: impl FnMut(&str) -> bool,
66 ) -> Self {
67 let mut this = Self::new(tools);
68 let mut tool_names_by_id = HashMap::default();
69
70 for message in messages {
71 match message.role {
72 Role::Assistant => {
73 if !message.tool_uses.is_empty() {
74 let tool_uses = message
75 .tool_uses
76 .iter()
77 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
78 .map(|tool_use| LanguageModelToolUse {
79 id: tool_use.id.clone(),
80 name: tool_use.name.clone().into(),
81 input: tool_use.input.clone(),
82 })
83 .collect::<Vec<_>>();
84
85 tool_names_by_id.extend(
86 tool_uses
87 .iter()
88 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
89 );
90
91 this.tool_uses_by_assistant_message
92 .insert(message.id, tool_uses);
93 }
94 }
95 Role::User => {
96 if !message.tool_results.is_empty() {
97 let tool_uses_by_user_message = this
98 .tool_uses_by_user_message
99 .entry(message.id)
100 .or_default();
101
102 for tool_result in &message.tool_results {
103 let tool_use_id = tool_result.tool_use_id.clone();
104 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
105 log::warn!("no tool name found for tool use: {tool_use_id:?}");
106 continue;
107 };
108
109 if !(filter_by_tool_name)(tool_use.as_ref()) {
110 continue;
111 }
112
113 tool_uses_by_user_message.push(tool_use_id.clone());
114 this.tool_results.insert(
115 tool_use_id.clone(),
116 LanguageModelToolResult {
117 tool_use_id,
118 tool_name: tool_use.clone(),
119 is_error: tool_result.is_error,
120 content: tool_result.content.clone(),
121 },
122 );
123 }
124 }
125 }
126 Role::System => {}
127 }
128 }
129
130 this
131 }
132
133 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
134 let mut pending_tools = Vec::new();
135 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
136 self.tool_results.insert(
137 tool_use_id.clone(),
138 LanguageModelToolResult {
139 tool_use_id,
140 tool_name: tool_use.name.clone(),
141 content: "Tool canceled by user".into(),
142 is_error: true,
143 },
144 );
145 pending_tools.push(tool_use.clone());
146 }
147 pending_tools
148 }
149
150 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
151 self.pending_tool_uses_by_id.values().collect()
152 }
153
154 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
155 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
156 return Vec::new();
157 };
158
159 let mut tool_uses = Vec::new();
160
161 for tool_use in tool_uses_for_message.iter() {
162 let tool_result = self.tool_results.get(&tool_use.id);
163
164 let status = (|| {
165 if let Some(tool_result) = tool_result {
166 return if tool_result.is_error {
167 ToolUseStatus::Error(tool_result.content.clone().into())
168 } else {
169 ToolUseStatus::Finished(tool_result.content.clone().into())
170 };
171 }
172
173 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
174 match pending_tool_use.status {
175 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
176 PendingToolUseStatus::NeedsConfirmation { .. } => {
177 ToolUseStatus::NeedsConfirmation
178 }
179 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
180 PendingToolUseStatus::Error(ref err) => {
181 ToolUseStatus::Error(err.clone().into())
182 }
183 }
184 } else {
185 ToolUseStatus::Pending
186 }
187 })();
188
189 let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
190 {
191 (tool.icon(), tool.needs_confirmation())
192 } else {
193 (IconName::Cog, false)
194 };
195
196 tool_uses.push(ToolUse {
197 id: tool_use.id.clone(),
198 name: tool_use.name.clone().into(),
199 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
200 input: tool_use.input.clone(),
201 status,
202 icon,
203 needs_confirmation,
204 })
205 }
206
207 tool_uses
208 }
209
210 pub fn tool_ui_label(
211 &self,
212 tool_name: &str,
213 input: &serde_json::Value,
214 cx: &App,
215 ) -> SharedString {
216 if let Some(tool) = self.tools.tool(tool_name, cx) {
217 tool.ui_text(input).into()
218 } else {
219 format!("Unknown tool {tool_name:?}").into()
220 }
221 }
222
223 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
224 let empty = Vec::new();
225
226 self.tool_uses_by_user_message
227 .get(&message_id)
228 .unwrap_or(&empty)
229 .iter()
230 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
231 .collect()
232 }
233
234 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
235 self.tool_uses_by_user_message
236 .get(&message_id)
237 .map_or(false, |results| !results.is_empty())
238 }
239
240 pub fn tool_result(
241 &self,
242 tool_use_id: &LanguageModelToolUseId,
243 ) -> Option<&LanguageModelToolResult> {
244 self.tool_results.get(tool_use_id)
245 }
246
247 pub fn request_tool_use(
248 &mut self,
249 assistant_message_id: MessageId,
250 tool_use: LanguageModelToolUse,
251 cx: &App,
252 ) {
253 self.tool_uses_by_assistant_message
254 .entry(assistant_message_id)
255 .or_default()
256 .push(tool_use.clone());
257
258 // The tool use is being requested by the Assistant, so we want to
259 // attach the tool results to the next user message.
260 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
261 self.tool_uses_by_user_message
262 .entry(next_user_message_id)
263 .or_default()
264 .push(tool_use.id.clone());
265
266 self.pending_tool_uses_by_id.insert(
267 tool_use.id.clone(),
268 PendingToolUse {
269 assistant_message_id,
270 id: tool_use.id,
271 name: tool_use.name.clone(),
272 ui_text: self
273 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
274 .into(),
275 input: tool_use.input,
276 status: PendingToolUseStatus::Idle,
277 },
278 );
279 }
280
281 pub fn run_pending_tool(
282 &mut self,
283 tool_use_id: LanguageModelToolUseId,
284 ui_text: SharedString,
285 task: Task<()>,
286 ) {
287 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
288 tool_use.ui_text = ui_text.into();
289 tool_use.status = PendingToolUseStatus::Running {
290 _task: task.shared(),
291 };
292 }
293 }
294
295 pub fn confirm_tool_use(
296 &mut self,
297 tool_use_id: LanguageModelToolUseId,
298 ui_text: impl Into<Arc<str>>,
299 input: serde_json::Value,
300 messages: Arc<Vec<LanguageModelRequestMessage>>,
301 tool: Arc<dyn Tool>,
302 ) {
303 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
304 let ui_text = ui_text.into();
305 tool_use.ui_text = ui_text.clone();
306 let confirmation = Confirmation {
307 tool_use_id,
308 input,
309 messages,
310 tool,
311 ui_text,
312 };
313 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
314 }
315 }
316
317 pub fn insert_tool_output(
318 &mut self,
319 tool_use_id: LanguageModelToolUseId,
320 tool_name: Arc<str>,
321 output: Result<String>,
322 ) -> Option<PendingToolUse> {
323 match output {
324 Ok(tool_result) => {
325 self.tool_results.insert(
326 tool_use_id.clone(),
327 LanguageModelToolResult {
328 tool_use_id: tool_use_id.clone(),
329 tool_name,
330 content: tool_result.into(),
331 is_error: false,
332 },
333 );
334 self.pending_tool_uses_by_id.remove(&tool_use_id)
335 }
336 Err(err) => {
337 self.tool_results.insert(
338 tool_use_id.clone(),
339 LanguageModelToolResult {
340 tool_use_id: tool_use_id.clone(),
341 tool_name,
342 content: err.to_string().into(),
343 is_error: true,
344 },
345 );
346
347 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
348 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
349 }
350
351 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
352 }
353 }
354 }
355
356 pub fn attach_tool_uses(
357 &self,
358 message_id: MessageId,
359 request_message: &mut LanguageModelRequestMessage,
360 ) {
361 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
362 let mut found_tool_use = false;
363
364 for tool_use in tool_uses {
365 if self.tool_results.contains_key(&tool_use.id) {
366 if !found_tool_use {
367 // The API fails if a message contains a tool use without any (non-whitespace) text around it
368 match request_message.content.last_mut() {
369 Some(MessageContent::Text(txt)) => {
370 if txt.is_empty() {
371 txt.push_str(USING_TOOL_MARKER);
372 }
373 }
374 None | Some(_) => {
375 request_message
376 .content
377 .push(MessageContent::Text(USING_TOOL_MARKER.into()));
378 }
379 };
380 }
381
382 found_tool_use = true;
383
384 // Do not send tool uses until they are completed
385 request_message
386 .content
387 .push(MessageContent::ToolUse(tool_use.clone()));
388 } else {
389 log::debug!(
390 "skipped tool use {:?} because it is still pending",
391 tool_use
392 );
393 }
394 }
395 }
396 }
397
398 pub fn attach_tool_results(
399 &self,
400 message_id: MessageId,
401 request_message: &mut LanguageModelRequestMessage,
402 ) {
403 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
404 for tool_use_id in tool_uses {
405 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
406 request_message.content.push(MessageContent::ToolResult(
407 LanguageModelToolResult {
408 tool_use_id: tool_use_id.clone(),
409 tool_name: tool_result.tool_name.clone(),
410 is_error: tool_result.is_error,
411 content: if tool_result.content.is_empty() {
412 // Surprisingly, the API fails if we return an empty string here.
413 // It thinks we are sending a tool use without a tool result.
414 "<Tool returned an empty string>".into()
415 } else {
416 tool_result.content.clone()
417 },
418 },
419 ));
420 }
421 }
422 }
423 }
424}
425
426#[derive(Debug, Clone)]
427pub struct PendingToolUse {
428 pub id: LanguageModelToolUseId,
429 /// The ID of the Assistant message in which the tool use was requested.
430 #[allow(unused)]
431 pub assistant_message_id: MessageId,
432 pub name: Arc<str>,
433 pub ui_text: Arc<str>,
434 pub input: serde_json::Value,
435 pub status: PendingToolUseStatus,
436}
437
438#[derive(Debug, Clone)]
439pub struct Confirmation {
440 pub tool_use_id: LanguageModelToolUseId,
441 pub input: serde_json::Value,
442 pub ui_text: Arc<str>,
443 pub messages: Arc<Vec<LanguageModelRequestMessage>>,
444 pub tool: Arc<dyn Tool>,
445}
446
447#[derive(Debug, Clone)]
448pub enum PendingToolUseStatus {
449 Idle,
450 NeedsConfirmation(Arc<Confirmation>),
451 Running { _task: Shared<Task<()>> },
452 Error(#[allow(unused)] Arc<str>),
453}
454
455impl PendingToolUseStatus {
456 pub fn is_idle(&self) -> bool {
457 matches!(self, PendingToolUseStatus::Idle)
458 }
459
460 pub fn is_error(&self) -> bool {
461 matches!(self, PendingToolUseStatus::Error(_))
462 }
463
464 pub fn needs_confirmation(&self) -> bool {
465 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
466 }
467}