1use crate::{
2 thread::{MessageId, PromptId, ThreadId},
3 thread_store::SerializedMessage,
4};
5use agent_settings::CompletionMode;
6use anyhow::Result;
7use assistant_tool::{
8 AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
9};
10use collections::HashMap;
11use futures::{FutureExt as _, future::Shared};
12use gpui::{App, Entity, SharedString, Task, Window};
13use icons::IconName;
14use language_model::{
15 ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
16 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
17 LanguageModelToolUseId, Role,
18};
19use project::Project;
20use std::sync::Arc;
21use util::truncate_lines_to_byte_limit;
22
23#[derive(Debug)]
24pub struct ToolUse {
25 pub id: LanguageModelToolUseId,
26 pub name: SharedString,
27 pub ui_text: SharedString,
28 pub status: ToolUseStatus,
29 pub input: serde_json::Value,
30 pub icon: icons::IconName,
31 pub needs_confirmation: bool,
32}
33
34pub struct ToolUseState {
35 tools: Entity<ToolWorkingSet>,
36 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
37 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
38 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
39 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
40 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
41}
42
43impl ToolUseState {
44 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
45 Self {
46 tools,
47 tool_uses_by_assistant_message: HashMap::default(),
48 tool_results: HashMap::default(),
49 pending_tool_uses_by_id: HashMap::default(),
50 tool_result_cards: HashMap::default(),
51 tool_use_metadata_by_id: HashMap::default(),
52 }
53 }
54
55 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
56 ///
57 /// Accepts a function to filter the tools that should be used to populate the state.
58 ///
59 /// If `window` is `None` (e.g., when in headless mode or when running evals),
60 /// tool cards won't be deserialized
61 pub fn from_serialized_messages(
62 tools: Entity<ToolWorkingSet>,
63 messages: &[SerializedMessage],
64 project: Entity<Project>,
65 window: Option<&mut Window>, // None in headless mode
66 cx: &mut App,
67 ) -> Self {
68 let mut this = Self::new(tools);
69 let mut tool_names_by_id = HashMap::default();
70 let mut window = window;
71
72 for message in messages {
73 match message.role {
74 Role::Assistant => {
75 if !message.tool_uses.is_empty() {
76 let tool_uses = message
77 .tool_uses
78 .iter()
79 .map(|tool_use| LanguageModelToolUse {
80 id: tool_use.id.clone(),
81 name: tool_use.name.clone().into(),
82 raw_input: tool_use.input.to_string(),
83 input: tool_use.input.clone(),
84 is_input_complete: true,
85 })
86 .collect::<Vec<_>>();
87
88 tool_names_by_id.extend(
89 tool_uses
90 .iter()
91 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
92 );
93
94 this.tool_uses_by_assistant_message
95 .insert(message.id, tool_uses);
96
97 for tool_result in &message.tool_results {
98 let tool_use_id = tool_result.tool_use_id.clone();
99 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100 log::warn!("no tool name found for tool use: {tool_use_id:?}");
101 continue;
102 };
103
104 this.tool_results.insert(
105 tool_use_id.clone(),
106 LanguageModelToolResult {
107 tool_use_id: tool_use_id.clone(),
108 tool_name: tool_use.clone(),
109 is_error: tool_result.is_error,
110 content: tool_result.content.clone(),
111 output: tool_result.output.clone(),
112 },
113 );
114
115 if let Some(window) = &mut window
116 && let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
117 && let Some(output) = tool_result.output.clone()
118 && let Some(card) = tool.deserialize_card(
119 output,
120 project.clone(),
121 window,
122 cx,
123 ) {
124 this.tool_result_cards.insert(tool_use_id, card);
125 }
126 }
127 }
128 }
129 Role::System | Role::User => {}
130 }
131 }
132
133 this
134 }
135
136 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
137 let mut canceled_tool_uses = Vec::new();
138 self.pending_tool_uses_by_id
139 .retain(|tool_use_id, tool_use| {
140 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
141 return true;
142 }
143
144 let content = "Tool canceled by user".into();
145 self.tool_results.insert(
146 tool_use_id.clone(),
147 LanguageModelToolResult {
148 tool_use_id: tool_use_id.clone(),
149 tool_name: tool_use.name.clone(),
150 content,
151 output: None,
152 is_error: true,
153 },
154 );
155 canceled_tool_uses.push(tool_use.clone());
156 false
157 });
158 canceled_tool_uses
159 }
160
161 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
162 self.pending_tool_uses_by_id.values().collect()
163 }
164
165 pub fn tool_uses_for_message(
166 &self,
167 id: MessageId,
168 project: &Entity<Project>,
169 cx: &App,
170 ) -> Vec<ToolUse> {
171 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
172 return Vec::new();
173 };
174
175 let mut tool_uses = Vec::new();
176
177 for tool_use in tool_uses_for_message.iter() {
178 let tool_result = self.tool_results.get(&tool_use.id);
179
180 let status = (|| {
181 if let Some(tool_result) = tool_result {
182 let content = tool_result
183 .content
184 .to_str()
185 .map(|str| str.to_owned().into())
186 .unwrap_or_default();
187
188 return if tool_result.is_error {
189 ToolUseStatus::Error(content)
190 } else {
191 ToolUseStatus::Finished(content)
192 };
193 }
194
195 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
196 match pending_tool_use.status {
197 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
198 PendingToolUseStatus::NeedsConfirmation { .. } => {
199 ToolUseStatus::NeedsConfirmation
200 }
201 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
202 PendingToolUseStatus::Error(ref err) => {
203 ToolUseStatus::Error(err.clone().into())
204 }
205 PendingToolUseStatus::InputStillStreaming => {
206 ToolUseStatus::InputStillStreaming
207 }
208 }
209 } else {
210 ToolUseStatus::Pending
211 }
212 })();
213
214 let (icon, needs_confirmation) =
215 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
216 (
217 tool.icon(),
218 tool.needs_confirmation(&tool_use.input, project, cx),
219 )
220 } else {
221 (IconName::Cog, false)
222 };
223
224 tool_uses.push(ToolUse {
225 id: tool_use.id.clone(),
226 name: tool_use.name.clone().into(),
227 ui_text: self.tool_ui_label(
228 &tool_use.name,
229 &tool_use.input,
230 tool_use.is_input_complete,
231 cx,
232 ),
233 input: tool_use.input.clone(),
234 status,
235 icon,
236 needs_confirmation,
237 })
238 }
239
240 tool_uses
241 }
242
243 pub fn tool_ui_label(
244 &self,
245 tool_name: &str,
246 input: &serde_json::Value,
247 is_input_complete: bool,
248 cx: &App,
249 ) -> SharedString {
250 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
251 if is_input_complete {
252 tool.ui_text(input).into()
253 } else {
254 tool.still_streaming_ui_text(input).into()
255 }
256 } else {
257 format!("Unknown tool {tool_name:?}").into()
258 }
259 }
260
261 pub fn tool_results_for_message(
262 &self,
263 assistant_message_id: MessageId,
264 ) -> Vec<&LanguageModelToolResult> {
265 let Some(tool_uses) = self
266 .tool_uses_by_assistant_message
267 .get(&assistant_message_id)
268 else {
269 return Vec::new();
270 };
271
272 tool_uses
273 .iter()
274 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
275 .collect()
276 }
277
278 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
279 self.tool_uses_by_assistant_message
280 .get(&assistant_message_id)
281 .map_or(false, |results| !results.is_empty())
282 }
283
284 pub fn tool_result(
285 &self,
286 tool_use_id: &LanguageModelToolUseId,
287 ) -> Option<&LanguageModelToolResult> {
288 self.tool_results.get(tool_use_id)
289 }
290
291 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
292 self.tool_result_cards.get(tool_use_id)
293 }
294
295 pub fn insert_tool_result_card(
296 &mut self,
297 tool_use_id: LanguageModelToolUseId,
298 card: AnyToolCard,
299 ) {
300 self.tool_result_cards.insert(tool_use_id, card);
301 }
302
303 pub fn request_tool_use(
304 &mut self,
305 assistant_message_id: MessageId,
306 tool_use: LanguageModelToolUse,
307 metadata: ToolUseMetadata,
308 cx: &App,
309 ) -> Arc<str> {
310 let tool_uses = self
311 .tool_uses_by_assistant_message
312 .entry(assistant_message_id)
313 .or_default();
314
315 let mut existing_tool_use_found = false;
316
317 for existing_tool_use in tool_uses.iter_mut() {
318 if existing_tool_use.id == tool_use.id {
319 *existing_tool_use = tool_use.clone();
320 existing_tool_use_found = true;
321 }
322 }
323
324 if !existing_tool_use_found {
325 tool_uses.push(tool_use.clone());
326 }
327
328 let status = if tool_use.is_input_complete {
329 self.tool_use_metadata_by_id
330 .insert(tool_use.id.clone(), metadata);
331
332 PendingToolUseStatus::Idle
333 } else {
334 PendingToolUseStatus::InputStillStreaming
335 };
336
337 let ui_text: Arc<str> = self
338 .tool_ui_label(
339 &tool_use.name,
340 &tool_use.input,
341 tool_use.is_input_complete,
342 cx,
343 )
344 .into();
345
346 let may_perform_edits = self
347 .tools
348 .read(cx)
349 .tool(&tool_use.name, cx)
350 .is_some_and(|tool| tool.may_perform_edits());
351
352 self.pending_tool_uses_by_id.insert(
353 tool_use.id.clone(),
354 PendingToolUse {
355 assistant_message_id,
356 id: tool_use.id,
357 name: tool_use.name.clone(),
358 ui_text: ui_text.clone(),
359 input: tool_use.input,
360 may_perform_edits,
361 status,
362 },
363 );
364
365 ui_text
366 }
367
368 pub fn run_pending_tool(
369 &mut self,
370 tool_use_id: LanguageModelToolUseId,
371 ui_text: SharedString,
372 task: Task<()>,
373 ) {
374 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
375 tool_use.ui_text = ui_text.into();
376 tool_use.status = PendingToolUseStatus::Running {
377 _task: task.shared(),
378 };
379 }
380 }
381
382 pub fn confirm_tool_use(
383 &mut self,
384 tool_use_id: LanguageModelToolUseId,
385 ui_text: impl Into<Arc<str>>,
386 input: serde_json::Value,
387 request: Arc<LanguageModelRequest>,
388 tool: Arc<dyn Tool>,
389 ) {
390 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
391 let ui_text = ui_text.into();
392 tool_use.ui_text = ui_text.clone();
393 let confirmation = Confirmation {
394 tool_use_id,
395 input,
396 request,
397 tool,
398 ui_text,
399 };
400 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
401 }
402 }
403
404 pub fn insert_tool_output(
405 &mut self,
406 tool_use_id: LanguageModelToolUseId,
407 tool_name: Arc<str>,
408 output: Result<ToolResultOutput>,
409 configured_model: Option<&ConfiguredModel>,
410 completion_mode: CompletionMode,
411 ) -> Option<PendingToolUse> {
412 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
413
414 telemetry::event!(
415 "Agent Tool Finished",
416 model = metadata
417 .as_ref()
418 .map(|metadata| metadata.model.telemetry_id()),
419 model_provider = metadata
420 .as_ref()
421 .map(|metadata| metadata.model.provider_id().to_string()),
422 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
423 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
424 tool_name,
425 success = output.is_ok()
426 );
427
428 match output {
429 Ok(output) => {
430 let tool_result = output.content;
431 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
432
433 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
434
435 // Protect from overly large output
436 let tool_output_limit = configured_model
437 .map(|model| {
438 model.model.max_token_count_for_mode(completion_mode.into()) as usize
439 * BYTES_PER_TOKEN_ESTIMATE
440 })
441 .unwrap_or(usize::MAX);
442
443 let content = match tool_result {
444 ToolResultContent::Text(text) => {
445 let text = if text.len() < tool_output_limit {
446 text
447 } else {
448 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
449 format!(
450 "Tool result too long. The first {} bytes:\n\n{}",
451 truncated.len(),
452 truncated
453 )
454 };
455 LanguageModelToolResultContent::Text(text.into())
456 }
457 ToolResultContent::Image(language_model_image) => {
458 if language_model_image.estimate_tokens() < tool_output_limit {
459 LanguageModelToolResultContent::Image(language_model_image)
460 } else {
461 self.tool_results.insert(
462 tool_use_id.clone(),
463 LanguageModelToolResult {
464 tool_use_id: tool_use_id.clone(),
465 tool_name,
466 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
467 is_error: true,
468 output: None,
469 },
470 );
471
472 return old_use;
473 }
474 }
475 };
476
477 self.tool_results.insert(
478 tool_use_id.clone(),
479 LanguageModelToolResult {
480 tool_use_id: tool_use_id.clone(),
481 tool_name,
482 content,
483 is_error: false,
484 output: output.output,
485 },
486 );
487
488 old_use
489 }
490 Err(err) => {
491 self.tool_results.insert(
492 tool_use_id.clone(),
493 LanguageModelToolResult {
494 tool_use_id: tool_use_id.clone(),
495 tool_name,
496 content: LanguageModelToolResultContent::Text(err.to_string().into()),
497 is_error: true,
498 output: None,
499 },
500 );
501
502 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
503 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
504 }
505
506 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
507 }
508 }
509 }
510
511 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
512 self.tool_uses_by_assistant_message
513 .contains_key(&assistant_message_id)
514 }
515
516 pub fn tool_results(
517 &self,
518 assistant_message_id: MessageId,
519 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
520 self.tool_uses_by_assistant_message
521 .get(&assistant_message_id)
522 .into_iter()
523 .flatten()
524 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
525 }
526}
527
528#[derive(Debug, Clone)]
529pub struct PendingToolUse {
530 pub id: LanguageModelToolUseId,
531 /// The ID of the Assistant message in which the tool use was requested.
532 #[allow(unused)]
533 pub assistant_message_id: MessageId,
534 pub name: Arc<str>,
535 pub ui_text: Arc<str>,
536 pub input: serde_json::Value,
537 pub status: PendingToolUseStatus,
538 pub may_perform_edits: bool,
539}
540
541#[derive(Debug, Clone)]
542pub struct Confirmation {
543 pub tool_use_id: LanguageModelToolUseId,
544 pub input: serde_json::Value,
545 pub ui_text: Arc<str>,
546 pub request: Arc<LanguageModelRequest>,
547 pub tool: Arc<dyn Tool>,
548}
549
550#[derive(Debug, Clone)]
551pub enum PendingToolUseStatus {
552 InputStillStreaming,
553 Idle,
554 NeedsConfirmation(Arc<Confirmation>),
555 Running { _task: Shared<Task<()>> },
556 Error(#[allow(unused)] Arc<str>),
557}
558
559impl PendingToolUseStatus {
560 pub fn is_idle(&self) -> bool {
561 matches!(self, PendingToolUseStatus::Idle)
562 }
563
564 pub fn is_error(&self) -> bool {
565 matches!(self, PendingToolUseStatus::Error(_))
566 }
567
568 pub fn needs_confirmation(&self) -> bool {
569 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
570 }
571}
572
573#[derive(Clone)]
574pub struct ToolUseMetadata {
575 pub model: Arc<dyn LanguageModel>,
576 pub thread_id: ThreadId,
577 pub prompt_id: PromptId,
578}