1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
19#[derive(Debug)]
20pub struct ToolUse {
21 pub id: LanguageModelToolUseId,
22 pub name: SharedString,
23 pub ui_text: SharedString,
24 pub status: ToolUseStatus,
25 pub input: serde_json::Value,
26 pub icon: ui::IconName,
27 pub needs_confirmation: bool,
28}
29
/// Tracks the tool uses of a thread: which assistant message requested each tool use,
/// which user message carries each result, the results themselves, and the tool uses
/// that have not finished yet.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in the order they were recorded.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful, failed, or cancelled), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have not completed successfully yet; a failed run stays here with
    /// an error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// UI cards associated with tool results, keyed by tool use ID.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Request metadata stored once a tool use's input is complete; removed when the
    /// tool's output is inserted and used for the telemetry event.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
39
40impl ToolUseState {
41 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
42 Self {
43 tools,
44 tool_uses_by_assistant_message: HashMap::default(),
45 tool_uses_by_user_message: HashMap::default(),
46 tool_results: HashMap::default(),
47 pending_tool_uses_by_id: HashMap::default(),
48 tool_result_cards: HashMap::default(),
49 tool_use_metadata_by_id: HashMap::default(),
50 }
51 }
52
53 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
54 ///
55 /// Accepts a function to filter the tools that should be used to populate the state.
56 pub fn from_serialized_messages(
57 tools: Entity<ToolWorkingSet>,
58 messages: &[SerializedMessage],
59 mut filter_by_tool_name: impl FnMut(&str) -> bool,
60 ) -> Self {
61 let mut this = Self::new(tools);
62 let mut tool_names_by_id = HashMap::default();
63
64 for message in messages {
65 match message.role {
66 Role::Assistant => {
67 if !message.tool_uses.is_empty() {
68 let tool_uses = message
69 .tool_uses
70 .iter()
71 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
72 .map(|tool_use| LanguageModelToolUse {
73 id: tool_use.id.clone(),
74 name: tool_use.name.clone().into(),
75 raw_input: tool_use.input.to_string(),
76 input: tool_use.input.clone(),
77 is_input_complete: true,
78 })
79 .collect::<Vec<_>>();
80
81 tool_names_by_id.extend(
82 tool_uses
83 .iter()
84 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
85 );
86
87 this.tool_uses_by_assistant_message
88 .insert(message.id, tool_uses);
89 }
90 }
91 Role::User => {
92 if !message.tool_results.is_empty() {
93 let tool_uses_by_user_message = this
94 .tool_uses_by_user_message
95 .entry(message.id)
96 .or_default();
97
98 for tool_result in &message.tool_results {
99 let tool_use_id = tool_result.tool_use_id.clone();
100 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
101 log::warn!("no tool name found for tool use: {tool_use_id:?}");
102 continue;
103 };
104
105 if !(filter_by_tool_name)(tool_use.as_ref()) {
106 continue;
107 }
108
109 tool_uses_by_user_message.push(tool_use_id.clone());
110 this.tool_results.insert(
111 tool_use_id.clone(),
112 LanguageModelToolResult {
113 tool_use_id,
114 tool_name: tool_use.clone(),
115 is_error: tool_result.is_error,
116 content: tool_result.content.clone(),
117 },
118 );
119 }
120 }
121 }
122 Role::System => {}
123 }
124 }
125
126 this
127 }
128
129 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130 let mut pending_tools = Vec::new();
131 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
132 self.tool_results.insert(
133 tool_use_id.clone(),
134 LanguageModelToolResult {
135 tool_use_id,
136 tool_name: tool_use.name.clone(),
137 content: "Tool canceled by user".into(),
138 is_error: true,
139 },
140 );
141 pending_tools.push(tool_use.clone());
142 }
143 pending_tools
144 }
145
146 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
147 self.pending_tool_uses_by_id.values().collect()
148 }
149
150 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
151 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
152 return Vec::new();
153 };
154
155 let mut tool_uses = Vec::new();
156
157 for tool_use in tool_uses_for_message.iter() {
158 let tool_result = self.tool_results.get(&tool_use.id);
159
160 let status = (|| {
161 if let Some(tool_result) = tool_result {
162 return if tool_result.is_error {
163 ToolUseStatus::Error(tool_result.content.clone().into())
164 } else {
165 ToolUseStatus::Finished(tool_result.content.clone().into())
166 };
167 }
168
169 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
170 match pending_tool_use.status {
171 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
172 PendingToolUseStatus::NeedsConfirmation { .. } => {
173 ToolUseStatus::NeedsConfirmation
174 }
175 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
176 PendingToolUseStatus::Error(ref err) => {
177 ToolUseStatus::Error(err.clone().into())
178 }
179 PendingToolUseStatus::InputStillStreaming => {
180 ToolUseStatus::InputStillStreaming
181 }
182 }
183 } else {
184 ToolUseStatus::Pending
185 }
186 })();
187
188 let (icon, needs_confirmation) =
189 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
190 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
191 } else {
192 (IconName::Cog, false)
193 };
194
195 tool_uses.push(ToolUse {
196 id: tool_use.id.clone(),
197 name: tool_use.name.clone().into(),
198 ui_text: self.tool_ui_label(
199 &tool_use.name,
200 &tool_use.input,
201 tool_use.is_input_complete,
202 cx,
203 ),
204 input: tool_use.input.clone(),
205 status,
206 icon,
207 needs_confirmation,
208 })
209 }
210
211 tool_uses
212 }
213
214 pub fn tool_ui_label(
215 &self,
216 tool_name: &str,
217 input: &serde_json::Value,
218 is_input_complete: bool,
219 cx: &App,
220 ) -> SharedString {
221 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
222 if is_input_complete {
223 tool.ui_text(input).into()
224 } else {
225 tool.still_streaming_ui_text(input).into()
226 }
227 } else {
228 format!("Unknown tool {tool_name:?}").into()
229 }
230 }
231
232 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
233 let empty = Vec::new();
234
235 self.tool_uses_by_user_message
236 .get(&message_id)
237 .unwrap_or(&empty)
238 .iter()
239 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
240 .collect()
241 }
242
243 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
244 self.tool_uses_by_user_message
245 .get(&message_id)
246 .map_or(false, |results| !results.is_empty())
247 }
248
249 pub fn tool_result(
250 &self,
251 tool_use_id: &LanguageModelToolUseId,
252 ) -> Option<&LanguageModelToolResult> {
253 self.tool_results.get(tool_use_id)
254 }
255
256 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
257 self.tool_result_cards.get(tool_use_id)
258 }
259
260 pub fn insert_tool_result_card(
261 &mut self,
262 tool_use_id: LanguageModelToolUseId,
263 card: AnyToolCard,
264 ) {
265 self.tool_result_cards.insert(tool_use_id, card);
266 }
267
268 pub fn request_tool_use(
269 &mut self,
270 assistant_message_id: MessageId,
271 tool_use: LanguageModelToolUse,
272 metadata: ToolUseMetadata,
273 cx: &App,
274 ) -> Arc<str> {
275 let tool_uses = self
276 .tool_uses_by_assistant_message
277 .entry(assistant_message_id)
278 .or_default();
279
280 let mut existing_tool_use_found = false;
281
282 for existing_tool_use in tool_uses.iter_mut() {
283 if existing_tool_use.id == tool_use.id {
284 *existing_tool_use = tool_use.clone();
285 existing_tool_use_found = true;
286 }
287 }
288
289 if !existing_tool_use_found {
290 tool_uses.push(tool_use.clone());
291 }
292
293 let status = if tool_use.is_input_complete {
294 self.tool_use_metadata_by_id
295 .insert(tool_use.id.clone(), metadata);
296
297 // The tool use is being requested by the Assistant, so we want to
298 // attach the tool results to the next user message.
299 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
300 self.tool_uses_by_user_message
301 .entry(next_user_message_id)
302 .or_default()
303 .push(tool_use.id.clone());
304
305 PendingToolUseStatus::Idle
306 } else {
307 PendingToolUseStatus::InputStillStreaming
308 };
309
310 let ui_text: Arc<str> = self
311 .tool_ui_label(
312 &tool_use.name,
313 &tool_use.input,
314 tool_use.is_input_complete,
315 cx,
316 )
317 .into();
318
319 self.pending_tool_uses_by_id.insert(
320 tool_use.id.clone(),
321 PendingToolUse {
322 assistant_message_id,
323 id: tool_use.id,
324 name: tool_use.name.clone(),
325 ui_text: ui_text.clone(),
326 input: tool_use.input,
327 status,
328 },
329 );
330
331 ui_text
332 }
333
334 pub fn run_pending_tool(
335 &mut self,
336 tool_use_id: LanguageModelToolUseId,
337 ui_text: SharedString,
338 task: Task<()>,
339 ) {
340 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
341 tool_use.ui_text = ui_text.into();
342 tool_use.status = PendingToolUseStatus::Running {
343 _task: task.shared(),
344 };
345 }
346 }
347
348 pub fn confirm_tool_use(
349 &mut self,
350 tool_use_id: LanguageModelToolUseId,
351 ui_text: impl Into<Arc<str>>,
352 input: serde_json::Value,
353 messages: Arc<Vec<LanguageModelRequestMessage>>,
354 tool: Arc<dyn Tool>,
355 ) {
356 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
357 let ui_text = ui_text.into();
358 tool_use.ui_text = ui_text.clone();
359 let confirmation = Confirmation {
360 tool_use_id,
361 input,
362 messages,
363 tool,
364 ui_text,
365 };
366 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
367 }
368 }
369
370 pub fn insert_tool_output(
371 &mut self,
372 tool_use_id: LanguageModelToolUseId,
373 tool_name: Arc<str>,
374 output: Result<String>,
375 cx: &App,
376 ) -> Option<PendingToolUse> {
377 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
378
379 telemetry::event!(
380 "Agent Tool Finished",
381 model = metadata
382 .as_ref()
383 .map(|metadata| metadata.model.telemetry_id()),
384 model_provider = metadata
385 .as_ref()
386 .map(|metadata| metadata.model.provider_id().to_string()),
387 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
388 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
389 tool_name,
390 success = output.is_ok()
391 );
392
393 match output {
394 Ok(tool_result) => {
395 let model_registry = LanguageModelRegistry::read_global(cx);
396
397 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
398
399 // Protect from clearly large output
400 let tool_output_limit = model_registry
401 .default_model()
402 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
403 .unwrap_or(usize::MAX);
404
405 let tool_result = if tool_result.len() <= tool_output_limit {
406 tool_result
407 } else {
408 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
409
410 format!(
411 "Tool result too long. The first {} bytes:\n\n{}",
412 truncated.len(),
413 truncated
414 )
415 };
416
417 self.tool_results.insert(
418 tool_use_id.clone(),
419 LanguageModelToolResult {
420 tool_use_id: tool_use_id.clone(),
421 tool_name,
422 content: tool_result.into(),
423 is_error: false,
424 },
425 );
426 self.pending_tool_uses_by_id.remove(&tool_use_id)
427 }
428 Err(err) => {
429 self.tool_results.insert(
430 tool_use_id.clone(),
431 LanguageModelToolResult {
432 tool_use_id: tool_use_id.clone(),
433 tool_name,
434 content: err.to_string().into(),
435 is_error: true,
436 },
437 );
438
439 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
440 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
441 }
442
443 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
444 }
445 }
446 }
447
448 pub fn attach_tool_uses(
449 &self,
450 message_id: MessageId,
451 request_message: &mut LanguageModelRequestMessage,
452 ) {
453 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
454 for tool_use in tool_uses {
455 if self.tool_results.contains_key(&tool_use.id) {
456 // Do not send tool uses until they are completed
457 request_message
458 .content
459 .push(MessageContent::ToolUse(tool_use.clone()));
460 } else {
461 log::debug!(
462 "skipped tool use {:?} because it is still pending",
463 tool_use
464 );
465 }
466 }
467 }
468 }
469
470 pub fn attach_tool_results(
471 &self,
472 message_id: MessageId,
473 request_message: &mut LanguageModelRequestMessage,
474 ) {
475 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
476 for tool_use_id in tool_uses {
477 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
478 request_message.content.push(MessageContent::ToolResult(
479 LanguageModelToolResult {
480 tool_use_id: tool_use_id.clone(),
481 tool_name: tool_result.tool_name.clone(),
482 is_error: tool_result.is_error,
483 content: if tool_result.content.is_empty() {
484 // Surprisingly, the API fails if we return an empty string here.
485 // It thinks we are sending a tool use without a tool result.
486 "<Tool returned an empty string>".into()
487 } else {
488 tool_result.content.clone()
489 },
490 },
491 ));
492 }
493 }
494 }
495 }
496}
497
/// A tool use that has been requested but has not completed successfully yet.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool to run.
    pub name: Arc<str>,
    /// Label shown for the tool use; replaced when it is sent for confirmation or starts
    /// running.
    pub ui_text: Arc<str>,
    /// Input for the tool; may be incomplete while the status is `InputStillStreaming`.
    pub input: serde_json::Value,
    /// Where the tool use is in its lifecycle.
    pub status: PendingToolUseStatus,
}
509
/// Everything needed to run a tool use once the user confirms it.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// ID of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Input the tool will be run with.
    pub input: serde_json::Value,
    /// Label shown while awaiting confirmation.
    pub ui_text: Arc<str>,
    /// Request messages captured when confirmation was requested (presumably handed to the
    /// tool when it runs; confirm with the caller).
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to run.
    pub tool: Arc<dyn Tool>,
}
518
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The model is still streaming the tool's input.
    InputStillStreaming,
    /// The input is complete; the tool has not been sent for confirmation or started.
    Idle,
    /// Waiting for the user to confirm the tool use.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running; holds a shared handle to the task running it.
    Running { _task: Shared<Task<()>> },
    /// Running the tool failed; carries the error message.
    Error(#[allow(unused)] Arc<str>),
}
527
528impl PendingToolUseStatus {
529 pub fn is_idle(&self) -> bool {
530 matches!(self, PendingToolUseStatus::Idle)
531 }
532
533 pub fn is_error(&self) -> bool {
534 matches!(self, PendingToolUseStatus::Error(_))
535 }
536
537 pub fn needs_confirmation(&self) -> bool {
538 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
539 }
540}
541
/// Context about the request that produced a tool use.
///
/// Stored by `ToolUseState::request_tool_use` once the tool's input is complete, and
/// consumed by `ToolUseState::insert_tool_output` for its telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model that requested the tool use.
    pub model: Arc<dyn LanguageModel>,
    /// Thread the tool use belongs to.
    pub thread_id: ThreadId,
    /// Prompt the tool use belongs to.
    pub prompt_id: PromptId,
}