1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
/// A tool use requested by the assistant, resolved into the form the UI displays
/// (see `ToolUseState::tool_uses_for_message`).
#[derive(Debug)]
pub struct ToolUse {
    /// Id of the tool use, as assigned in the assistant's request.
    pub id: LanguageModelToolUseId,
    /// Name of the requested tool.
    pub name: SharedString,
    /// Label for this invocation: the tool's `ui_text` for the input, or an
    /// "Unknown tool" fallback when the tool is not in the working set.
    pub ui_text: SharedString,
    /// Status derived from the recorded result or the pending entry, if any.
    pub status: ToolUseStatus,
    /// JSON input the assistant supplied for the tool.
    pub input: serde_json::Value,
    /// The tool's icon, or `IconName::Cog` when the tool is unknown.
    pub icon: ui::IconName,
    /// Whether the tool reports that this input requires user confirmation
    /// (`false` when the tool is unknown).
    pub needs_confirmation: bool,
}
29
/// Text that `ToolUseState::attach_tool_uses` inserts into a request message so that a
/// message carrying tool uses does not consist of tool uses alone (the API rejects that).
pub const USING_TOOL_MARKER: &str = "<using_tool>";
31
/// Bookkeeping for the tool uses requested by the assistant: which message requested them,
/// which user message carries their results, the results themselves, and the ones still pending.
pub struct ToolUseState {
    /// Working set used to look up tools by name (icons, UI labels, confirmation needs).
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in the order they were recorded.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Ids of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Results (successful, failed, or canceled) keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have neither finished successfully nor been canceled.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered through `insert_tool_result_card`, keyed by tool use id.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata recorded when a tool use is requested; removed and reported in telemetry
    /// when its output is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
41
42impl ToolUseState {
43 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
44 Self {
45 tools,
46 tool_uses_by_assistant_message: HashMap::default(),
47 tool_uses_by_user_message: HashMap::default(),
48 tool_results: HashMap::default(),
49 pending_tool_uses_by_id: HashMap::default(),
50 tool_result_cards: HashMap::default(),
51 tool_use_metadata_by_id: HashMap::default(),
52 }
53 }
54
55 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
56 ///
57 /// Accepts a function to filter the tools that should be used to populate the state.
58 pub fn from_serialized_messages(
59 tools: Entity<ToolWorkingSet>,
60 messages: &[SerializedMessage],
61 mut filter_by_tool_name: impl FnMut(&str) -> bool,
62 ) -> Self {
63 let mut this = Self::new(tools);
64 let mut tool_names_by_id = HashMap::default();
65
66 for message in messages {
67 match message.role {
68 Role::Assistant => {
69 if !message.tool_uses.is_empty() {
70 let tool_uses = message
71 .tool_uses
72 .iter()
73 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
74 .map(|tool_use| LanguageModelToolUse {
75 id: tool_use.id.clone(),
76 name: tool_use.name.clone().into(),
77 input: tool_use.input.clone(),
78 })
79 .collect::<Vec<_>>();
80
81 tool_names_by_id.extend(
82 tool_uses
83 .iter()
84 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
85 );
86
87 this.tool_uses_by_assistant_message
88 .insert(message.id, tool_uses);
89 }
90 }
91 Role::User => {
92 if !message.tool_results.is_empty() {
93 let tool_uses_by_user_message = this
94 .tool_uses_by_user_message
95 .entry(message.id)
96 .or_default();
97
98 for tool_result in &message.tool_results {
99 let tool_use_id = tool_result.tool_use_id.clone();
100 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
101 log::warn!("no tool name found for tool use: {tool_use_id:?}");
102 continue;
103 };
104
105 if !(filter_by_tool_name)(tool_use.as_ref()) {
106 continue;
107 }
108
109 tool_uses_by_user_message.push(tool_use_id.clone());
110 this.tool_results.insert(
111 tool_use_id.clone(),
112 LanguageModelToolResult {
113 tool_use_id,
114 tool_name: tool_use.clone(),
115 is_error: tool_result.is_error,
116 content: tool_result.content.clone(),
117 },
118 );
119 }
120 }
121 }
122 Role::System => {}
123 }
124 }
125
126 this
127 }
128
129 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130 let mut pending_tools = Vec::new();
131 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
132 self.tool_results.insert(
133 tool_use_id.clone(),
134 LanguageModelToolResult {
135 tool_use_id,
136 tool_name: tool_use.name.clone(),
137 content: "Tool canceled by user".into(),
138 is_error: true,
139 },
140 );
141 pending_tools.push(tool_use.clone());
142 }
143 pending_tools
144 }
145
146 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
147 self.pending_tool_uses_by_id.values().collect()
148 }
149
150 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
151 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
152 return Vec::new();
153 };
154
155 let mut tool_uses = Vec::new();
156
157 for tool_use in tool_uses_for_message.iter() {
158 let tool_result = self.tool_results.get(&tool_use.id);
159
160 let status = (|| {
161 if let Some(tool_result) = tool_result {
162 return if tool_result.is_error {
163 ToolUseStatus::Error(tool_result.content.clone().into())
164 } else {
165 ToolUseStatus::Finished(tool_result.content.clone().into())
166 };
167 }
168
169 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
170 match pending_tool_use.status {
171 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
172 PendingToolUseStatus::NeedsConfirmation { .. } => {
173 ToolUseStatus::NeedsConfirmation
174 }
175 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
176 PendingToolUseStatus::Error(ref err) => {
177 ToolUseStatus::Error(err.clone().into())
178 }
179 }
180 } else {
181 ToolUseStatus::Pending
182 }
183 })();
184
185 let (icon, needs_confirmation) =
186 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
187 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
188 } else {
189 (IconName::Cog, false)
190 };
191
192 tool_uses.push(ToolUse {
193 id: tool_use.id.clone(),
194 name: tool_use.name.clone().into(),
195 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
196 input: tool_use.input.clone(),
197 status,
198 icon,
199 needs_confirmation,
200 })
201 }
202
203 tool_uses
204 }
205
206 pub fn tool_ui_label(
207 &self,
208 tool_name: &str,
209 input: &serde_json::Value,
210 cx: &App,
211 ) -> SharedString {
212 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
213 tool.ui_text(input).into()
214 } else {
215 format!("Unknown tool {tool_name:?}").into()
216 }
217 }
218
219 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
220 let empty = Vec::new();
221
222 self.tool_uses_by_user_message
223 .get(&message_id)
224 .unwrap_or(&empty)
225 .iter()
226 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
227 .collect()
228 }
229
230 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
231 self.tool_uses_by_user_message
232 .get(&message_id)
233 .map_or(false, |results| !results.is_empty())
234 }
235
236 pub fn tool_result(
237 &self,
238 tool_use_id: &LanguageModelToolUseId,
239 ) -> Option<&LanguageModelToolResult> {
240 self.tool_results.get(tool_use_id)
241 }
242
243 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
244 self.tool_result_cards.get(tool_use_id)
245 }
246
247 pub fn insert_tool_result_card(
248 &mut self,
249 tool_use_id: LanguageModelToolUseId,
250 card: AnyToolCard,
251 ) {
252 self.tool_result_cards.insert(tool_use_id, card);
253 }
254
255 pub fn request_tool_use(
256 &mut self,
257 assistant_message_id: MessageId,
258 tool_use: LanguageModelToolUse,
259 metadata: ToolUseMetadata,
260 cx: &App,
261 ) {
262 self.tool_uses_by_assistant_message
263 .entry(assistant_message_id)
264 .or_default()
265 .push(tool_use.clone());
266
267 self.tool_use_metadata_by_id
268 .insert(tool_use.id.clone(), metadata);
269
270 // The tool use is being requested by the Assistant, so we want to
271 // attach the tool results to the next user message.
272 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
273 self.tool_uses_by_user_message
274 .entry(next_user_message_id)
275 .or_default()
276 .push(tool_use.id.clone());
277
278 self.pending_tool_uses_by_id.insert(
279 tool_use.id.clone(),
280 PendingToolUse {
281 assistant_message_id,
282 id: tool_use.id,
283 name: tool_use.name.clone(),
284 ui_text: self
285 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
286 .into(),
287 input: tool_use.input,
288 status: PendingToolUseStatus::Idle,
289 },
290 );
291 }
292
293 pub fn run_pending_tool(
294 &mut self,
295 tool_use_id: LanguageModelToolUseId,
296 ui_text: SharedString,
297 task: Task<()>,
298 ) {
299 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
300 tool_use.ui_text = ui_text.into();
301 tool_use.status = PendingToolUseStatus::Running {
302 _task: task.shared(),
303 };
304 }
305 }
306
307 pub fn confirm_tool_use(
308 &mut self,
309 tool_use_id: LanguageModelToolUseId,
310 ui_text: impl Into<Arc<str>>,
311 input: serde_json::Value,
312 messages: Arc<Vec<LanguageModelRequestMessage>>,
313 tool: Arc<dyn Tool>,
314 ) {
315 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
316 let ui_text = ui_text.into();
317 tool_use.ui_text = ui_text.clone();
318 let confirmation = Confirmation {
319 tool_use_id,
320 input,
321 messages,
322 tool,
323 ui_text,
324 };
325 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
326 }
327 }
328
329 pub fn insert_tool_output(
330 &mut self,
331 tool_use_id: LanguageModelToolUseId,
332 tool_name: Arc<str>,
333 output: Result<String>,
334 cx: &App,
335 ) -> Option<PendingToolUse> {
336 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
337
338 telemetry::event!(
339 "Agent Tool Finished",
340 model = metadata
341 .as_ref()
342 .map(|metadata| metadata.model.telemetry_id()),
343 model_provider = metadata
344 .as_ref()
345 .map(|metadata| metadata.model.provider_id().to_string()),
346 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
347 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
348 tool_name,
349 success = output.is_ok()
350 );
351
352 match output {
353 Ok(tool_result) => {
354 let model_registry = LanguageModelRegistry::read_global(cx);
355
356 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
357
358 // Protect from clearly large output
359 let tool_output_limit = model_registry
360 .default_model()
361 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
362 .unwrap_or(usize::MAX);
363
364 let tool_result = if tool_result.len() <= tool_output_limit {
365 tool_result
366 } else {
367 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
368
369 format!(
370 "Tool result too long. The first {} bytes:\n\n{}",
371 truncated.len(),
372 truncated
373 )
374 };
375
376 self.tool_results.insert(
377 tool_use_id.clone(),
378 LanguageModelToolResult {
379 tool_use_id: tool_use_id.clone(),
380 tool_name,
381 content: tool_result.into(),
382 is_error: false,
383 },
384 );
385 self.pending_tool_uses_by_id.remove(&tool_use_id)
386 }
387 Err(err) => {
388 self.tool_results.insert(
389 tool_use_id.clone(),
390 LanguageModelToolResult {
391 tool_use_id: tool_use_id.clone(),
392 tool_name,
393 content: err.to_string().into(),
394 is_error: true,
395 },
396 );
397
398 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
399 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
400 }
401
402 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
403 }
404 }
405 }
406
407 pub fn attach_tool_uses(
408 &self,
409 message_id: MessageId,
410 request_message: &mut LanguageModelRequestMessage,
411 ) {
412 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
413 let mut found_tool_use = false;
414
415 for tool_use in tool_uses {
416 if self.tool_results.contains_key(&tool_use.id) {
417 if !found_tool_use {
418 // The API fails if a message contains a tool use without any (non-whitespace) text around it
419 match request_message.content.last_mut() {
420 Some(MessageContent::Text(txt)) => {
421 if txt.is_empty() {
422 txt.push_str(USING_TOOL_MARKER);
423 }
424 }
425 None | Some(_) => {
426 request_message
427 .content
428 .push(MessageContent::Text(USING_TOOL_MARKER.into()));
429 }
430 };
431 }
432
433 found_tool_use = true;
434
435 // Do not send tool uses until they are completed
436 request_message
437 .content
438 .push(MessageContent::ToolUse(tool_use.clone()));
439 } else {
440 log::debug!(
441 "skipped tool use {:?} because it is still pending",
442 tool_use
443 );
444 }
445 }
446 }
447 }
448
449 pub fn attach_tool_results(
450 &self,
451 message_id: MessageId,
452 request_message: &mut LanguageModelRequestMessage,
453 ) {
454 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
455 for tool_use_id in tool_uses {
456 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
457 request_message.content.push(MessageContent::ToolResult(
458 LanguageModelToolResult {
459 tool_use_id: tool_use_id.clone(),
460 tool_name: tool_result.tool_name.clone(),
461 is_error: tool_result.is_error,
462 content: if tool_result.content.is_empty() {
463 // Surprisingly, the API fails if we return an empty string here.
464 // It thinks we are sending a tool use without a tool result.
465 "<Tool returned an empty string>".into()
466 } else {
467 tool_result.content.clone()
468 },
469 },
470 ));
471 }
472 }
473 }
474 }
475}
476
/// A tool use that was requested and has neither finished successfully nor been canceled.
///
/// Failed tool uses stay in this state with [`PendingToolUseStatus::Error`].
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Id of the tool use this entry tracks.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read by the code in this module, hence the `allow`.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the invocation; refreshed when the tool is run or awaits confirmation.
    pub ui_text: Arc<str>,
    /// JSON input the assistant supplied for the tool.
    pub input: serde_json::Value,
    /// Where the tool use currently is in its lifecycle.
    pub status: PendingToolUseStatus,
}
488
/// Everything captured when a pending tool use is put on hold awaiting user confirmation
/// (see `ToolUseState::confirm_tool_use`).
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Id of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input the tool would be run with.
    pub input: serde_json::Value,
    /// Label shown for the invocation.
    pub ui_text: Arc<str>,
    /// Request messages captured alongside the confirmation (presumably the conversation
    /// context the tool runs against — confirm with the caller).
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool that would be run.
    pub tool: Arc<dyn Tool>,
}
497
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Requested but not started yet.
    Idle,
    /// Waiting for the user to confirm the tool use.
    NeedsConfirmation(Arc<Confirmation>),
    /// Running. The shared task handle is held so the task stays alive while this status
    /// is current (presumably dropping it would cancel the tool — confirm against gpui `Task`).
    Running { _task: Shared<Task<()>> },
    /// The tool failed; carries the error message.
    Error(#[allow(unused)] Arc<str>),
}
505
506impl PendingToolUseStatus {
507 pub fn is_idle(&self) -> bool {
508 matches!(self, PendingToolUseStatus::Idle)
509 }
510
511 pub fn is_error(&self) -> bool {
512 matches!(self, PendingToolUseStatus::Error(_))
513 }
514
515 pub fn needs_confirmation(&self) -> bool {
516 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
517 }
518}
519
/// Context recorded when a tool use is requested and reported in the
/// "Agent Tool Finished" telemetry event once its output arrives.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model that requested the tool use (its telemetry id and provider id are reported).
    pub model: Arc<dyn LanguageModel>,
    /// Thread the tool use belongs to.
    pub thread_id: ThreadId,
    /// Prompt that led to the tool use.
    pub prompt_id: PromptId,
}