1use crate::{
2 thread::{MessageId, PromptId, ThreadId},
3 thread_store::SerializedMessage,
4};
5use agent_settings::CompletionMode;
6use anyhow::Result;
7use assistant_tool::{
8 AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
9};
10use collections::HashMap;
11use futures::{FutureExt as _, future::Shared};
12use gpui::{App, Entity, SharedString, Task, Window};
13use icons::IconName;
14use language_model::{
15 ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
16 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
17 LanguageModelToolUseId, Role,
18};
19use project::Project;
20use std::sync::Arc;
21use util::truncate_lines_to_byte_limit;
22
23#[derive(Debug)]
24pub struct ToolUse {
25 pub id: LanguageModelToolUseId,
26 pub name: SharedString,
27 pub ui_text: SharedString,
28 pub status: ToolUseStatus,
29 pub input: serde_json::Value,
30 pub icon: icons::IconName,
31 pub needs_confirmation: bool,
32}
33
34pub struct ToolUseState {
35 tools: Entity<ToolWorkingSet>,
36 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
37 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
38 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
39 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
40 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
41}
42
43impl ToolUseState {
44 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
45 Self {
46 tools,
47 tool_uses_by_assistant_message: HashMap::default(),
48 tool_results: HashMap::default(),
49 pending_tool_uses_by_id: HashMap::default(),
50 tool_result_cards: HashMap::default(),
51 tool_use_metadata_by_id: HashMap::default(),
52 }
53 }
54
55 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
56 ///
57 /// Accepts a function to filter the tools that should be used to populate the state.
58 ///
59 /// If `window` is `None` (e.g., when in headless mode or when running evals),
60 /// tool cards won't be deserialized
61 pub fn from_serialized_messages(
62 tools: Entity<ToolWorkingSet>,
63 messages: &[SerializedMessage],
64 project: Entity<Project>,
65 window: Option<&mut Window>, // None in headless mode
66 cx: &mut App,
67 ) -> Self {
68 let mut this = Self::new(tools);
69 let mut tool_names_by_id = HashMap::default();
70 let mut window = window;
71
72 for message in messages {
73 match message.role {
74 Role::Assistant => {
75 if !message.tool_uses.is_empty() {
76 let tool_uses = message
77 .tool_uses
78 .iter()
79 .map(|tool_use| LanguageModelToolUse {
80 id: tool_use.id.clone(),
81 name: tool_use.name.clone().into(),
82 raw_input: tool_use.input.to_string(),
83 input: tool_use.input.clone(),
84 is_input_complete: true,
85 })
86 .collect::<Vec<_>>();
87
88 tool_names_by_id.extend(
89 tool_uses
90 .iter()
91 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
92 );
93
94 this.tool_uses_by_assistant_message
95 .insert(message.id, tool_uses);
96
97 for tool_result in &message.tool_results {
98 let tool_use_id = tool_result.tool_use_id.clone();
99 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100 log::warn!("no tool name found for tool use: {tool_use_id:?}");
101 continue;
102 };
103
104 this.tool_results.insert(
105 tool_use_id.clone(),
106 LanguageModelToolResult {
107 tool_use_id: tool_use_id.clone(),
108 tool_name: tool_use.clone(),
109 is_error: tool_result.is_error,
110 content: tool_result.content.clone(),
111 output: tool_result.output.clone(),
112 },
113 );
114
115 if let Some(window) = &mut window {
116 if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
117 if let Some(output) = tool_result.output.clone() {
118 if let Some(card) = tool.deserialize_card(
119 output,
120 project.clone(),
121 window,
122 cx,
123 ) {
124 this.tool_result_cards.insert(tool_use_id, card);
125 }
126 }
127 }
128 }
129 }
130 }
131 }
132 Role::System | Role::User => {}
133 }
134 }
135
136 this
137 }
138
139 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
140 let mut cancelled_tool_uses = Vec::new();
141 self.pending_tool_uses_by_id
142 .retain(|tool_use_id, tool_use| {
143 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
144 return true;
145 }
146
147 let content = "Tool canceled by user".into();
148 self.tool_results.insert(
149 tool_use_id.clone(),
150 LanguageModelToolResult {
151 tool_use_id: tool_use_id.clone(),
152 tool_name: tool_use.name.clone(),
153 content,
154 output: None,
155 is_error: true,
156 },
157 );
158 cancelled_tool_uses.push(tool_use.clone());
159 false
160 });
161 cancelled_tool_uses
162 }
163
164 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
165 self.pending_tool_uses_by_id.values().collect()
166 }
167
168 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
169 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
170 return Vec::new();
171 };
172
173 let mut tool_uses = Vec::new();
174
175 for tool_use in tool_uses_for_message.iter() {
176 let tool_result = self.tool_results.get(&tool_use.id);
177
178 let status = (|| {
179 if let Some(tool_result) = tool_result {
180 let content = tool_result
181 .content
182 .to_str()
183 .map(|str| str.to_owned().into())
184 .unwrap_or_default();
185
186 return if tool_result.is_error {
187 ToolUseStatus::Error(content)
188 } else {
189 ToolUseStatus::Finished(content)
190 };
191 }
192
193 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
194 match pending_tool_use.status {
195 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
196 PendingToolUseStatus::NeedsConfirmation { .. } => {
197 ToolUseStatus::NeedsConfirmation
198 }
199 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
200 PendingToolUseStatus::Error(ref err) => {
201 ToolUseStatus::Error(err.clone().into())
202 }
203 PendingToolUseStatus::InputStillStreaming => {
204 ToolUseStatus::InputStillStreaming
205 }
206 }
207 } else {
208 ToolUseStatus::Pending
209 }
210 })();
211
212 let (icon, needs_confirmation) =
213 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
214 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
215 } else {
216 (IconName::Cog, false)
217 };
218
219 tool_uses.push(ToolUse {
220 id: tool_use.id.clone(),
221 name: tool_use.name.clone().into(),
222 ui_text: self.tool_ui_label(
223 &tool_use.name,
224 &tool_use.input,
225 tool_use.is_input_complete,
226 cx,
227 ),
228 input: tool_use.input.clone(),
229 status,
230 icon,
231 needs_confirmation,
232 })
233 }
234
235 tool_uses
236 }
237
238 pub fn tool_ui_label(
239 &self,
240 tool_name: &str,
241 input: &serde_json::Value,
242 is_input_complete: bool,
243 cx: &App,
244 ) -> SharedString {
245 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
246 if is_input_complete {
247 tool.ui_text(input).into()
248 } else {
249 tool.still_streaming_ui_text(input).into()
250 }
251 } else {
252 format!("Unknown tool {tool_name:?}").into()
253 }
254 }
255
256 pub fn tool_results_for_message(
257 &self,
258 assistant_message_id: MessageId,
259 ) -> Vec<&LanguageModelToolResult> {
260 let Some(tool_uses) = self
261 .tool_uses_by_assistant_message
262 .get(&assistant_message_id)
263 else {
264 return Vec::new();
265 };
266
267 tool_uses
268 .iter()
269 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
270 .collect()
271 }
272
273 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
274 self.tool_uses_by_assistant_message
275 .get(&assistant_message_id)
276 .map_or(false, |results| !results.is_empty())
277 }
278
279 pub fn tool_result(
280 &self,
281 tool_use_id: &LanguageModelToolUseId,
282 ) -> Option<&LanguageModelToolResult> {
283 self.tool_results.get(tool_use_id)
284 }
285
286 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
287 self.tool_result_cards.get(tool_use_id)
288 }
289
290 pub fn insert_tool_result_card(
291 &mut self,
292 tool_use_id: LanguageModelToolUseId,
293 card: AnyToolCard,
294 ) {
295 self.tool_result_cards.insert(tool_use_id, card);
296 }
297
298 pub fn request_tool_use(
299 &mut self,
300 assistant_message_id: MessageId,
301 tool_use: LanguageModelToolUse,
302 metadata: ToolUseMetadata,
303 cx: &App,
304 ) -> Arc<str> {
305 let tool_uses = self
306 .tool_uses_by_assistant_message
307 .entry(assistant_message_id)
308 .or_default();
309
310 let mut existing_tool_use_found = false;
311
312 for existing_tool_use in tool_uses.iter_mut() {
313 if existing_tool_use.id == tool_use.id {
314 *existing_tool_use = tool_use.clone();
315 existing_tool_use_found = true;
316 }
317 }
318
319 if !existing_tool_use_found {
320 tool_uses.push(tool_use.clone());
321 }
322
323 let status = if tool_use.is_input_complete {
324 self.tool_use_metadata_by_id
325 .insert(tool_use.id.clone(), metadata);
326
327 PendingToolUseStatus::Idle
328 } else {
329 PendingToolUseStatus::InputStillStreaming
330 };
331
332 let ui_text: Arc<str> = self
333 .tool_ui_label(
334 &tool_use.name,
335 &tool_use.input,
336 tool_use.is_input_complete,
337 cx,
338 )
339 .into();
340
341 let may_perform_edits = self
342 .tools
343 .read(cx)
344 .tool(&tool_use.name, cx)
345 .is_some_and(|tool| tool.may_perform_edits());
346
347 self.pending_tool_uses_by_id.insert(
348 tool_use.id.clone(),
349 PendingToolUse {
350 assistant_message_id,
351 id: tool_use.id,
352 name: tool_use.name.clone(),
353 ui_text: ui_text.clone(),
354 input: tool_use.input,
355 may_perform_edits,
356 status,
357 },
358 );
359
360 ui_text
361 }
362
363 pub fn run_pending_tool(
364 &mut self,
365 tool_use_id: LanguageModelToolUseId,
366 ui_text: SharedString,
367 task: Task<()>,
368 ) {
369 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
370 tool_use.ui_text = ui_text.into();
371 tool_use.status = PendingToolUseStatus::Running {
372 _task: task.shared(),
373 };
374 }
375 }
376
377 pub fn confirm_tool_use(
378 &mut self,
379 tool_use_id: LanguageModelToolUseId,
380 ui_text: impl Into<Arc<str>>,
381 input: serde_json::Value,
382 request: Arc<LanguageModelRequest>,
383 tool: Arc<dyn Tool>,
384 ) {
385 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
386 let ui_text = ui_text.into();
387 tool_use.ui_text = ui_text.clone();
388 let confirmation = Confirmation {
389 tool_use_id,
390 input,
391 request,
392 tool,
393 ui_text,
394 };
395 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
396 }
397 }
398
399 pub fn insert_tool_output(
400 &mut self,
401 tool_use_id: LanguageModelToolUseId,
402 tool_name: Arc<str>,
403 output: Result<ToolResultOutput>,
404 configured_model: Option<&ConfiguredModel>,
405 completion_mode: CompletionMode,
406 ) -> Option<PendingToolUse> {
407 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
408
409 telemetry::event!(
410 "Agent Tool Finished",
411 model = metadata
412 .as_ref()
413 .map(|metadata| metadata.model.telemetry_id()),
414 model_provider = metadata
415 .as_ref()
416 .map(|metadata| metadata.model.provider_id().to_string()),
417 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
418 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
419 tool_name,
420 success = output.is_ok()
421 );
422
423 match output {
424 Ok(output) => {
425 let tool_result = output.content;
426 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
427
428 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
429
430 // Protect from overly large output
431 let tool_output_limit = configured_model
432 .map(|model| {
433 model.model.max_token_count_for_mode(completion_mode.into()) as usize
434 * BYTES_PER_TOKEN_ESTIMATE
435 })
436 .unwrap_or(usize::MAX);
437
438 let content = match tool_result {
439 ToolResultContent::Text(text) => {
440 let text = if text.len() < tool_output_limit {
441 text
442 } else {
443 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
444 format!(
445 "Tool result too long. The first {} bytes:\n\n{}",
446 truncated.len(),
447 truncated
448 )
449 };
450 LanguageModelToolResultContent::Text(text.into())
451 }
452 ToolResultContent::Image(language_model_image) => {
453 if language_model_image.estimate_tokens() < tool_output_limit {
454 LanguageModelToolResultContent::Image(language_model_image)
455 } else {
456 self.tool_results.insert(
457 tool_use_id.clone(),
458 LanguageModelToolResult {
459 tool_use_id: tool_use_id.clone(),
460 tool_name,
461 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
462 is_error: true,
463 output: None,
464 },
465 );
466
467 return old_use;
468 }
469 }
470 };
471
472 self.tool_results.insert(
473 tool_use_id.clone(),
474 LanguageModelToolResult {
475 tool_use_id: tool_use_id.clone(),
476 tool_name,
477 content,
478 is_error: false,
479 output: output.output,
480 },
481 );
482
483 old_use
484 }
485 Err(err) => {
486 self.tool_results.insert(
487 tool_use_id.clone(),
488 LanguageModelToolResult {
489 tool_use_id: tool_use_id.clone(),
490 tool_name,
491 content: LanguageModelToolResultContent::Text(err.to_string().into()),
492 is_error: true,
493 output: None,
494 },
495 );
496
497 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
498 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
499 }
500
501 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
502 }
503 }
504 }
505
506 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
507 self.tool_uses_by_assistant_message
508 .contains_key(&assistant_message_id)
509 }
510
511 pub fn tool_results(
512 &self,
513 assistant_message_id: MessageId,
514 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
515 self.tool_uses_by_assistant_message
516 .get(&assistant_message_id)
517 .into_iter()
518 .flatten()
519 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
520 }
521}
522
523#[derive(Debug, Clone)]
524pub struct PendingToolUse {
525 pub id: LanguageModelToolUseId,
526 /// The ID of the Assistant message in which the tool use was requested.
527 #[allow(unused)]
528 pub assistant_message_id: MessageId,
529 pub name: Arc<str>,
530 pub ui_text: Arc<str>,
531 pub input: serde_json::Value,
532 pub status: PendingToolUseStatus,
533 pub may_perform_edits: bool,
534}
535
536#[derive(Debug, Clone)]
537pub struct Confirmation {
538 pub tool_use_id: LanguageModelToolUseId,
539 pub input: serde_json::Value,
540 pub ui_text: Arc<str>,
541 pub request: Arc<LanguageModelRequest>,
542 pub tool: Arc<dyn Tool>,
543}
544
545#[derive(Debug, Clone)]
546pub enum PendingToolUseStatus {
547 InputStillStreaming,
548 Idle,
549 NeedsConfirmation(Arc<Confirmation>),
550 Running { _task: Shared<Task<()>> },
551 Error(#[allow(unused)] Arc<str>),
552}
553
554impl PendingToolUseStatus {
555 pub fn is_idle(&self) -> bool {
556 matches!(self, PendingToolUseStatus::Idle)
557 }
558
559 pub fn is_error(&self) -> bool {
560 matches!(self, PendingToolUseStatus::Error(_))
561 }
562
563 pub fn needs_confirmation(&self) -> bool {
564 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
565 }
566}
567
568#[derive(Clone)]
569pub struct ToolUseMetadata {
570 pub model: Arc<dyn LanguageModel>,
571 pub thread_id: ThreadId,
572 pub prompt_id: PromptId,
573}