1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
19#[derive(Debug)]
20pub struct ToolUse {
21 pub id: LanguageModelToolUseId,
22 pub name: SharedString,
23 pub ui_text: SharedString,
24 pub status: ToolUseStatus,
25 pub input: serde_json::Value,
26 pub icon: ui::IconName,
27 pub needs_confirmation: bool,
28}
29
30pub struct ToolUseState {
31 tools: Entity<ToolWorkingSet>,
32 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
33 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
34 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
35 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
36 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
37}
38
39impl ToolUseState {
40 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
41 Self {
42 tools,
43 tool_uses_by_assistant_message: HashMap::default(),
44 tool_results: HashMap::default(),
45 pending_tool_uses_by_id: HashMap::default(),
46 tool_result_cards: HashMap::default(),
47 tool_use_metadata_by_id: HashMap::default(),
48 }
49 }
50
51 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
52 ///
53 /// Accepts a function to filter the tools that should be used to populate the state.
54 pub fn from_serialized_messages(
55 tools: Entity<ToolWorkingSet>,
56 messages: &[SerializedMessage],
57 ) -> Self {
58 let mut this = Self::new(tools);
59 let mut tool_names_by_id = HashMap::default();
60
61 for message in messages {
62 match message.role {
63 Role::Assistant => {
64 if !message.tool_uses.is_empty() {
65 let tool_uses = message
66 .tool_uses
67 .iter()
68 .map(|tool_use| LanguageModelToolUse {
69 id: tool_use.id.clone(),
70 name: tool_use.name.clone().into(),
71 raw_input: tool_use.input.to_string(),
72 input: tool_use.input.clone(),
73 is_input_complete: true,
74 })
75 .collect::<Vec<_>>();
76
77 tool_names_by_id.extend(
78 tool_uses
79 .iter()
80 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
81 );
82
83 this.tool_uses_by_assistant_message
84 .insert(message.id, tool_uses);
85
86 for tool_result in &message.tool_results {
87 let tool_use_id = tool_result.tool_use_id.clone();
88 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
89 log::warn!("no tool name found for tool use: {tool_use_id:?}");
90 continue;
91 };
92
93 this.tool_results.insert(
94 tool_use_id.clone(),
95 LanguageModelToolResult {
96 tool_use_id,
97 tool_name: tool_use.clone(),
98 is_error: tool_result.is_error,
99 content: tool_result.content.clone(),
100 },
101 );
102 }
103 }
104 }
105 Role::System | Role::User => {}
106 }
107 }
108
109 this
110 }
111
112 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
113 let mut cancelled_tool_uses = Vec::new();
114 self.pending_tool_uses_by_id
115 .retain(|tool_use_id, tool_use| {
116 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
117 return true;
118 }
119
120 let content = "Tool canceled by user".into();
121 self.tool_results.insert(
122 tool_use_id.clone(),
123 LanguageModelToolResult {
124 tool_use_id: tool_use_id.clone(),
125 tool_name: tool_use.name.clone(),
126 content,
127 is_error: true,
128 },
129 );
130 cancelled_tool_uses.push(tool_use.clone());
131 false
132 });
133 cancelled_tool_uses
134 }
135
136 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
137 self.pending_tool_uses_by_id.values().collect()
138 }
139
140 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
141 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
142 return Vec::new();
143 };
144
145 let mut tool_uses = Vec::new();
146
147 for tool_use in tool_uses_for_message.iter() {
148 let tool_result = self.tool_results.get(&tool_use.id);
149
150 let status = (|| {
151 if let Some(tool_result) = tool_result {
152 return if tool_result.is_error {
153 ToolUseStatus::Error(tool_result.content.clone().into())
154 } else {
155 ToolUseStatus::Finished(tool_result.content.clone().into())
156 };
157 }
158
159 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
160 match pending_tool_use.status {
161 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
162 PendingToolUseStatus::NeedsConfirmation { .. } => {
163 ToolUseStatus::NeedsConfirmation
164 }
165 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
166 PendingToolUseStatus::Error(ref err) => {
167 ToolUseStatus::Error(err.clone().into())
168 }
169 PendingToolUseStatus::InputStillStreaming => {
170 ToolUseStatus::InputStillStreaming
171 }
172 }
173 } else {
174 ToolUseStatus::Pending
175 }
176 })();
177
178 let (icon, needs_confirmation) =
179 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
180 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
181 } else {
182 (IconName::Cog, false)
183 };
184
185 tool_uses.push(ToolUse {
186 id: tool_use.id.clone(),
187 name: tool_use.name.clone().into(),
188 ui_text: self.tool_ui_label(
189 &tool_use.name,
190 &tool_use.input,
191 tool_use.is_input_complete,
192 cx,
193 ),
194 input: tool_use.input.clone(),
195 status,
196 icon,
197 needs_confirmation,
198 })
199 }
200
201 tool_uses
202 }
203
204 pub fn tool_ui_label(
205 &self,
206 tool_name: &str,
207 input: &serde_json::Value,
208 is_input_complete: bool,
209 cx: &App,
210 ) -> SharedString {
211 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
212 if is_input_complete {
213 tool.ui_text(input).into()
214 } else {
215 tool.still_streaming_ui_text(input).into()
216 }
217 } else {
218 format!("Unknown tool {tool_name:?}").into()
219 }
220 }
221
222 pub fn tool_results_for_message(
223 &self,
224 assistant_message_id: MessageId,
225 ) -> Vec<&LanguageModelToolResult> {
226 let Some(tool_uses) = self
227 .tool_uses_by_assistant_message
228 .get(&assistant_message_id)
229 else {
230 return Vec::new();
231 };
232
233 tool_uses
234 .iter()
235 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
236 .collect()
237 }
238
239 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
240 self.tool_uses_by_assistant_message
241 .get(&assistant_message_id)
242 .map_or(false, |results| !results.is_empty())
243 }
244
245 pub fn tool_result(
246 &self,
247 tool_use_id: &LanguageModelToolUseId,
248 ) -> Option<&LanguageModelToolResult> {
249 self.tool_results.get(tool_use_id)
250 }
251
252 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
253 self.tool_result_cards.get(tool_use_id)
254 }
255
256 pub fn insert_tool_result_card(
257 &mut self,
258 tool_use_id: LanguageModelToolUseId,
259 card: AnyToolCard,
260 ) {
261 self.tool_result_cards.insert(tool_use_id, card);
262 }
263
264 pub fn request_tool_use(
265 &mut self,
266 assistant_message_id: MessageId,
267 tool_use: LanguageModelToolUse,
268 metadata: ToolUseMetadata,
269 cx: &App,
270 ) -> Arc<str> {
271 let tool_uses = self
272 .tool_uses_by_assistant_message
273 .entry(assistant_message_id)
274 .or_default();
275
276 let mut existing_tool_use_found = false;
277
278 for existing_tool_use in tool_uses.iter_mut() {
279 if existing_tool_use.id == tool_use.id {
280 *existing_tool_use = tool_use.clone();
281 existing_tool_use_found = true;
282 }
283 }
284
285 if !existing_tool_use_found {
286 tool_uses.push(tool_use.clone());
287 }
288
289 let status = if tool_use.is_input_complete {
290 self.tool_use_metadata_by_id
291 .insert(tool_use.id.clone(), metadata);
292
293 PendingToolUseStatus::Idle
294 } else {
295 PendingToolUseStatus::InputStillStreaming
296 };
297
298 let ui_text: Arc<str> = self
299 .tool_ui_label(
300 &tool_use.name,
301 &tool_use.input,
302 tool_use.is_input_complete,
303 cx,
304 )
305 .into();
306
307 self.pending_tool_uses_by_id.insert(
308 tool_use.id.clone(),
309 PendingToolUse {
310 assistant_message_id,
311 id: tool_use.id,
312 name: tool_use.name.clone(),
313 ui_text: ui_text.clone(),
314 input: tool_use.input,
315 status,
316 },
317 );
318
319 ui_text
320 }
321
322 pub fn run_pending_tool(
323 &mut self,
324 tool_use_id: LanguageModelToolUseId,
325 ui_text: SharedString,
326 task: Task<()>,
327 ) {
328 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
329 tool_use.ui_text = ui_text.into();
330 tool_use.status = PendingToolUseStatus::Running {
331 _task: task.shared(),
332 };
333 }
334 }
335
336 pub fn confirm_tool_use(
337 &mut self,
338 tool_use_id: LanguageModelToolUseId,
339 ui_text: impl Into<Arc<str>>,
340 input: serde_json::Value,
341 messages: Arc<Vec<LanguageModelRequestMessage>>,
342 tool: Arc<dyn Tool>,
343 ) {
344 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
345 let ui_text = ui_text.into();
346 tool_use.ui_text = ui_text.clone();
347 let confirmation = Confirmation {
348 tool_use_id,
349 input,
350 messages,
351 tool,
352 ui_text,
353 };
354 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
355 }
356 }
357
358 pub fn insert_tool_output(
359 &mut self,
360 tool_use_id: LanguageModelToolUseId,
361 tool_name: Arc<str>,
362 output: Result<String>,
363 configured_model: Option<&ConfiguredModel>,
364 ) -> Option<PendingToolUse> {
365 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
366
367 telemetry::event!(
368 "Agent Tool Finished",
369 model = metadata
370 .as_ref()
371 .map(|metadata| metadata.model.telemetry_id()),
372 model_provider = metadata
373 .as_ref()
374 .map(|metadata| metadata.model.provider_id().to_string()),
375 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
376 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
377 tool_name,
378 success = output.is_ok()
379 );
380
381 match output {
382 Ok(tool_result) => {
383 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
384
385 // Protect from clearly large output
386 let tool_output_limit = configured_model
387 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
388 .unwrap_or(usize::MAX);
389
390 let tool_result = if tool_result.len() <= tool_output_limit {
391 tool_result
392 } else {
393 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
394
395 format!(
396 "Tool result too long. The first {} bytes:\n\n{}",
397 truncated.len(),
398 truncated
399 )
400 };
401
402 self.tool_results.insert(
403 tool_use_id.clone(),
404 LanguageModelToolResult {
405 tool_use_id: tool_use_id.clone(),
406 tool_name,
407 content: tool_result.into(),
408 is_error: false,
409 },
410 );
411 self.pending_tool_uses_by_id.remove(&tool_use_id)
412 }
413 Err(err) => {
414 self.tool_results.insert(
415 tool_use_id.clone(),
416 LanguageModelToolResult {
417 tool_use_id: tool_use_id.clone(),
418 tool_name,
419 content: err.to_string().into(),
420 is_error: true,
421 },
422 );
423
424 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
425 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
426 }
427
428 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
429 }
430 }
431 }
432
433 pub fn attach_tool_uses(
434 &self,
435 message_id: MessageId,
436 request_message: &mut LanguageModelRequestMessage,
437 ) {
438 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
439 for tool_use in tool_uses {
440 if self.tool_results.contains_key(&tool_use.id) {
441 // Do not send tool uses until they are completed
442 request_message
443 .content
444 .push(MessageContent::ToolUse(tool_use.clone()));
445 } else {
446 log::debug!(
447 "skipped tool use {:?} because it is still pending",
448 tool_use
449 );
450 }
451 }
452 }
453 }
454
455 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
456 self.tool_uses_by_assistant_message
457 .contains_key(&assistant_message_id)
458 }
459
460 pub fn tool_results_message(
461 &self,
462 assistant_message_id: MessageId,
463 ) -> Option<LanguageModelRequestMessage> {
464 let tool_uses = self
465 .tool_uses_by_assistant_message
466 .get(&assistant_message_id)?;
467
468 if tool_uses.is_empty() {
469 return None;
470 }
471
472 let mut request_message = LanguageModelRequestMessage {
473 role: Role::User,
474 content: vec![],
475 cache: false,
476 };
477
478 for tool_use in tool_uses {
479 if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
480 request_message
481 .content
482 .push(MessageContent::ToolResult(LanguageModelToolResult {
483 tool_use_id: tool_use.id.clone(),
484 tool_name: tool_result.tool_name.clone(),
485 is_error: tool_result.is_error,
486 content: if tool_result.content.is_empty() {
487 // Surprisingly, the API fails if we return an empty string here.
488 // It thinks we are sending a tool use without a tool result.
489 "<Tool returned an empty string>".into()
490 } else {
491 tool_result.content.clone()
492 },
493 }));
494 }
495 }
496
497 Some(request_message)
498 }
499}
500
501#[derive(Debug, Clone)]
502pub struct PendingToolUse {
503 pub id: LanguageModelToolUseId,
504 /// The ID of the Assistant message in which the tool use was requested.
505 #[allow(unused)]
506 pub assistant_message_id: MessageId,
507 pub name: Arc<str>,
508 pub ui_text: Arc<str>,
509 pub input: serde_json::Value,
510 pub status: PendingToolUseStatus,
511}
512
513#[derive(Debug, Clone)]
514pub struct Confirmation {
515 pub tool_use_id: LanguageModelToolUseId,
516 pub input: serde_json::Value,
517 pub ui_text: Arc<str>,
518 pub messages: Arc<Vec<LanguageModelRequestMessage>>,
519 pub tool: Arc<dyn Tool>,
520}
521
522#[derive(Debug, Clone)]
523pub enum PendingToolUseStatus {
524 InputStillStreaming,
525 Idle,
526 NeedsConfirmation(Arc<Confirmation>),
527 Running { _task: Shared<Task<()>> },
528 Error(#[allow(unused)] Arc<str>),
529}
530
531impl PendingToolUseStatus {
532 pub fn is_idle(&self) -> bool {
533 matches!(self, PendingToolUseStatus::Idle)
534 }
535
536 pub fn is_error(&self) -> bool {
537 matches!(self, PendingToolUseStatus::Error(_))
538 }
539
540 pub fn needs_confirmation(&self) -> bool {
541 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
542 }
543}
544
545#[derive(Clone)]
546pub struct ToolUseMetadata {
547 pub model: Arc<dyn LanguageModel>,
548 pub thread_id: ThreadId,
549 pub prompt_id: PromptId,
550}