1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::MessageId;
17use crate::thread_store::SerializedMessage;
18
19#[derive(Debug)]
20pub struct ToolUse {
21 pub id: LanguageModelToolUseId,
22 pub name: SharedString,
23 pub ui_text: SharedString,
24 pub status: ToolUseStatus,
25 pub input: serde_json::Value,
26 pub icon: ui::IconName,
27 pub needs_confirmation: bool,
28}
29
/// Display status of a tool use, derived from its recorded result (if any) or from
/// its pending state.
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The pending tool use is in the `NeedsConfirmation` state (awaiting confirmation).
    NeedsConfirmation,
    /// The tool use is idle, or there is no pending entry and no result for it yet.
    Pending,
    /// The tool is currently running.
    Running,
    /// The tool finished successfully; carries the tool's output.
    Finished(SharedString),
    /// The tool failed (or was canceled); carries the error message.
    Error(SharedString),
}
38
39impl ToolUseStatus {
40 pub fn text(&self) -> SharedString {
41 match self {
42 ToolUseStatus::NeedsConfirmation => "".into(),
43 ToolUseStatus::Pending => "".into(),
44 ToolUseStatus::Running => "".into(),
45 ToolUseStatus::Finished(out) => out.clone(),
46 ToolUseStatus::Error(out) => out.clone(),
47 }
48 }
49}
50
/// Bookkeeping for all tool uses in a thread: which assistant message requested
/// which tool uses, which user message carries which results, the results
/// themselves, and the tool uses that have not finished yet.
pub struct ToolUseState {
    /// Working set used to look up tools by name (for icons, labels, confirmation).
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in request order.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded result (success or error) for each tool use that has one.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have not completed successfully yet. Successful ones are
    /// removed on completion; errored ones stay here with `Error` status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
58
/// Placeholder text added to a request message that contains tool uses but no
/// text of its own, because the API rejects tool uses without surrounding text.
pub const USING_TOOL_MARKER: &str = "<using_tool>";
60
61impl ToolUseState {
62 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
63 Self {
64 tools,
65 tool_uses_by_assistant_message: HashMap::default(),
66 tool_uses_by_user_message: HashMap::default(),
67 tool_results: HashMap::default(),
68 pending_tool_uses_by_id: HashMap::default(),
69 }
70 }
71
72 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
73 ///
74 /// Accepts a function to filter the tools that should be used to populate the state.
75 pub fn from_serialized_messages(
76 tools: Entity<ToolWorkingSet>,
77 messages: &[SerializedMessage],
78 mut filter_by_tool_name: impl FnMut(&str) -> bool,
79 ) -> Self {
80 let mut this = Self::new(tools);
81 let mut tool_names_by_id = HashMap::default();
82
83 for message in messages {
84 match message.role {
85 Role::Assistant => {
86 if !message.tool_uses.is_empty() {
87 let tool_uses = message
88 .tool_uses
89 .iter()
90 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
91 .map(|tool_use| LanguageModelToolUse {
92 id: tool_use.id.clone(),
93 name: tool_use.name.clone().into(),
94 input: tool_use.input.clone(),
95 })
96 .collect::<Vec<_>>();
97
98 tool_names_by_id.extend(
99 tool_uses
100 .iter()
101 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
102 );
103
104 this.tool_uses_by_assistant_message
105 .insert(message.id, tool_uses);
106 }
107 }
108 Role::User => {
109 if !message.tool_results.is_empty() {
110 let tool_uses_by_user_message = this
111 .tool_uses_by_user_message
112 .entry(message.id)
113 .or_default();
114
115 for tool_result in &message.tool_results {
116 let tool_use_id = tool_result.tool_use_id.clone();
117 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
118 log::warn!("no tool name found for tool use: {tool_use_id:?}");
119 continue;
120 };
121
122 if !(filter_by_tool_name)(tool_use.as_ref()) {
123 continue;
124 }
125
126 tool_uses_by_user_message.push(tool_use_id.clone());
127 this.tool_results.insert(
128 tool_use_id.clone(),
129 LanguageModelToolResult {
130 tool_use_id,
131 tool_name: tool_use.clone(),
132 is_error: tool_result.is_error,
133 content: tool_result.content.clone(),
134 },
135 );
136 }
137 }
138 }
139 Role::System => {}
140 }
141 }
142
143 this
144 }
145
146 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
147 let mut pending_tools = Vec::new();
148 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
149 self.tool_results.insert(
150 tool_use_id.clone(),
151 LanguageModelToolResult {
152 tool_use_id,
153 tool_name: tool_use.name.clone(),
154 content: "Tool canceled by user".into(),
155 is_error: true,
156 },
157 );
158 pending_tools.push(tool_use.clone());
159 }
160 pending_tools
161 }
162
163 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
164 self.pending_tool_uses_by_id.values().collect()
165 }
166
167 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
168 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169 return Vec::new();
170 };
171
172 let mut tool_uses = Vec::new();
173
174 for tool_use in tool_uses_for_message.iter() {
175 let tool_result = self.tool_results.get(&tool_use.id);
176
177 let status = (|| {
178 if let Some(tool_result) = tool_result {
179 return if tool_result.is_error {
180 ToolUseStatus::Error(tool_result.content.clone().into())
181 } else {
182 ToolUseStatus::Finished(tool_result.content.clone().into())
183 };
184 }
185
186 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
187 match pending_tool_use.status {
188 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
189 PendingToolUseStatus::NeedsConfirmation { .. } => {
190 ToolUseStatus::NeedsConfirmation
191 }
192 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
193 PendingToolUseStatus::Error(ref err) => {
194 ToolUseStatus::Error(err.clone().into())
195 }
196 }
197 } else {
198 ToolUseStatus::Pending
199 }
200 })();
201
202 let (icon, needs_confirmation) =
203 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
204 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
205 } else {
206 (IconName::Cog, false)
207 };
208
209 tool_uses.push(ToolUse {
210 id: tool_use.id.clone(),
211 name: tool_use.name.clone().into(),
212 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
213 input: tool_use.input.clone(),
214 status,
215 icon,
216 needs_confirmation,
217 })
218 }
219
220 tool_uses
221 }
222
223 pub fn tool_ui_label(
224 &self,
225 tool_name: &str,
226 input: &serde_json::Value,
227 cx: &App,
228 ) -> SharedString {
229 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
230 tool.ui_text(input).into()
231 } else {
232 format!("Unknown tool {tool_name:?}").into()
233 }
234 }
235
236 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
237 let empty = Vec::new();
238
239 self.tool_uses_by_user_message
240 .get(&message_id)
241 .unwrap_or(&empty)
242 .iter()
243 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
244 .collect()
245 }
246
247 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
248 self.tool_uses_by_user_message
249 .get(&message_id)
250 .map_or(false, |results| !results.is_empty())
251 }
252
253 pub fn tool_result(
254 &self,
255 tool_use_id: &LanguageModelToolUseId,
256 ) -> Option<&LanguageModelToolResult> {
257 self.tool_results.get(tool_use_id)
258 }
259
260 pub fn request_tool_use(
261 &mut self,
262 assistant_message_id: MessageId,
263 tool_use: LanguageModelToolUse,
264 cx: &App,
265 ) {
266 self.tool_uses_by_assistant_message
267 .entry(assistant_message_id)
268 .or_default()
269 .push(tool_use.clone());
270
271 // The tool use is being requested by the Assistant, so we want to
272 // attach the tool results to the next user message.
273 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
274 self.tool_uses_by_user_message
275 .entry(next_user_message_id)
276 .or_default()
277 .push(tool_use.id.clone());
278
279 self.pending_tool_uses_by_id.insert(
280 tool_use.id.clone(),
281 PendingToolUse {
282 assistant_message_id,
283 id: tool_use.id,
284 name: tool_use.name.clone(),
285 ui_text: self
286 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
287 .into(),
288 input: tool_use.input,
289 status: PendingToolUseStatus::Idle,
290 },
291 );
292 }
293
294 pub fn run_pending_tool(
295 &mut self,
296 tool_use_id: LanguageModelToolUseId,
297 ui_text: SharedString,
298 task: Task<()>,
299 ) {
300 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
301 tool_use.ui_text = ui_text.into();
302 tool_use.status = PendingToolUseStatus::Running {
303 _task: task.shared(),
304 };
305 }
306 }
307
308 pub fn confirm_tool_use(
309 &mut self,
310 tool_use_id: LanguageModelToolUseId,
311 ui_text: impl Into<Arc<str>>,
312 input: serde_json::Value,
313 messages: Arc<Vec<LanguageModelRequestMessage>>,
314 tool: Arc<dyn Tool>,
315 ) {
316 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
317 let ui_text = ui_text.into();
318 tool_use.ui_text = ui_text.clone();
319 let confirmation = Confirmation {
320 tool_use_id,
321 input,
322 messages,
323 tool,
324 ui_text,
325 };
326 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
327 }
328 }
329
330 pub fn insert_tool_output(
331 &mut self,
332 tool_use_id: LanguageModelToolUseId,
333 tool_name: Arc<str>,
334 output: Result<String>,
335 cx: &App,
336 ) -> Option<PendingToolUse> {
337 telemetry::event!("Agent Tool Finished", tool_name, success = output.is_ok());
338
339 match output {
340 Ok(tool_result) => {
341 let model_registry = LanguageModelRegistry::read_global(cx);
342
343 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
344
345 // Protect from clearly large output
346 let tool_output_limit = model_registry
347 .default_model()
348 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
349 .unwrap_or(usize::MAX);
350
351 let tool_result = if tool_result.len() <= tool_output_limit {
352 tool_result
353 } else {
354 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
355
356 format!(
357 "Tool result too long. The first {} bytes:\n\n{}",
358 truncated.len(),
359 truncated
360 )
361 };
362
363 self.tool_results.insert(
364 tool_use_id.clone(),
365 LanguageModelToolResult {
366 tool_use_id: tool_use_id.clone(),
367 tool_name,
368 content: tool_result.into(),
369 is_error: false,
370 },
371 );
372 self.pending_tool_uses_by_id.remove(&tool_use_id)
373 }
374 Err(err) => {
375 self.tool_results.insert(
376 tool_use_id.clone(),
377 LanguageModelToolResult {
378 tool_use_id: tool_use_id.clone(),
379 tool_name,
380 content: err.to_string().into(),
381 is_error: true,
382 },
383 );
384
385 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
386 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
387 }
388
389 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
390 }
391 }
392 }
393
394 pub fn attach_tool_uses(
395 &self,
396 message_id: MessageId,
397 request_message: &mut LanguageModelRequestMessage,
398 ) {
399 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
400 let mut found_tool_use = false;
401
402 for tool_use in tool_uses {
403 if self.tool_results.contains_key(&tool_use.id) {
404 if !found_tool_use {
405 // The API fails if a message contains a tool use without any (non-whitespace) text around it
406 match request_message.content.last_mut() {
407 Some(MessageContent::Text(txt)) => {
408 if txt.is_empty() {
409 txt.push_str(USING_TOOL_MARKER);
410 }
411 }
412 None | Some(_) => {
413 request_message
414 .content
415 .push(MessageContent::Text(USING_TOOL_MARKER.into()));
416 }
417 };
418 }
419
420 found_tool_use = true;
421
422 // Do not send tool uses until they are completed
423 request_message
424 .content
425 .push(MessageContent::ToolUse(tool_use.clone()));
426 } else {
427 log::debug!(
428 "skipped tool use {:?} because it is still pending",
429 tool_use
430 );
431 }
432 }
433 }
434 }
435
436 pub fn attach_tool_results(
437 &self,
438 message_id: MessageId,
439 request_message: &mut LanguageModelRequestMessage,
440 ) {
441 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
442 for tool_use_id in tool_uses {
443 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
444 request_message.content.push(MessageContent::ToolResult(
445 LanguageModelToolResult {
446 tool_use_id: tool_use_id.clone(),
447 tool_name: tool_result.tool_name.clone(),
448 is_error: tool_result.is_error,
449 content: if tool_result.content.is_empty() {
450 // Surprisingly, the API fails if we return an empty string here.
451 // It thinks we are sending a tool use without a tool result.
452 "<Tool returned an empty string>".into()
453 } else {
454 tool_result.content.clone()
455 },
456 },
457 ));
458 }
459 }
460 }
461 }
462}
463
/// A tool use that has been requested but has not completed successfully yet.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use, as assigned by the language model.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // The `allow(unused)` suggests nothing reads this field yet; it is recorded for
    // reference (TODO confirm whether it is still needed).
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label shown in the UI; refreshed when the tool is confirmed or run.
    pub ui_text: Arc<str>,
    /// The JSON input the model passed to the tool.
    pub input: serde_json::Value,
    /// Current lifecycle state of this tool use.
    pub status: PendingToolUseStatus,
}
475
/// Everything captured when a tool use is moved into `NeedsConfirmation`,
/// presumably so the tool can be run once the user confirms it.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// The tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// The JSON input the tool will be run with.
    pub input: serde_json::Value,
    /// Label shown in the UI for this tool use.
    pub ui_text: Arc<str>,
    /// Request messages supplied when confirmation was requested.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to run.
    pub tool: Arc<dyn Tool>,
}
484
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Requested by the assistant but not yet confirmed or started.
    Idle,
    /// Waiting for confirmation; holds what is needed to run the tool afterwards.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running. The task is only held, not read (hence the leading
    /// underscore) — presumably to keep it from being dropped.
    Running { _task: Shared<Task<()>> },
    /// The tool failed with the given error message.
    Error(#[allow(unused)] Arc<str>),
}
492
493impl PendingToolUseStatus {
494 pub fn is_idle(&self) -> bool {
495 matches!(self, PendingToolUseStatus::Idle)
496 }
497
498 pub fn is_error(&self) -> bool {
499 matches!(self, PendingToolUseStatus::Error(_))
500 }
501
502 pub fn needs_confirmation(&self) -> bool {
503 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
504 }
505}