1use crate::{
2 thread::{MessageId, PromptId, ThreadId},
3 thread_store::SerializedMessage,
4};
5use agent_settings::CompletionMode;
6use anyhow::Result;
7use assistant_tool::{
8 AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
9};
10use collections::HashMap;
11use futures::{FutureExt as _, future::Shared};
12use gpui::{App, Entity, SharedString, Task, Window};
13use icons::IconName;
14use language_model::{
15 ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
16 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
17 LanguageModelToolUseId, Role,
18};
19use project::Project;
20use std::sync::Arc;
21use util::truncate_lines_to_byte_limit;
22
/// A tool use belonging to a message, as produced by
/// [`ToolUseState::tool_uses_for_message`]: the recorded request combined
/// with its current status and presentation details.
#[derive(Debug)]
pub struct ToolUse {
    /// ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label describing the tool use, as computed by [`ToolUseState::tool_ui_label`].
    pub ui_text: SharedString,
    /// Current status, derived from the stored result or the pending entry.
    pub status: ToolUseStatus,
    /// JSON input the tool was invoked with.
    pub input: serde_json::Value,
    /// Icon of the tool; `IconName::Cog` when the tool is not found in the working set.
    pub icon: icons::IconName,
    /// Whether the tool reports that this input needs confirmation; `false`
    /// when the tool is not found in the working set.
    pub needs_confirmation: bool,
}
33
/// Tracks the tool uses requested by assistant messages, together with their
/// pending state, results, result cards, and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses recorded for each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Results recorded so far, keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have neither produced a successful
    /// output nor been canceled. Entries whose tool failed stay here with an
    /// `Error` status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards associated with tool uses, keyed by tool use ID.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete and removed again
    /// when its output is inserted (it is reported via telemetry at that point).
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
42
43impl ToolUseState {
44 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
45 Self {
46 tools,
47 tool_uses_by_assistant_message: HashMap::default(),
48 tool_results: HashMap::default(),
49 pending_tool_uses_by_id: HashMap::default(),
50 tool_result_cards: HashMap::default(),
51 tool_use_metadata_by_id: HashMap::default(),
52 }
53 }
54
55 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
56 ///
57 /// Accepts a function to filter the tools that should be used to populate the state.
58 ///
59 /// If `window` is `None` (e.g., when in headless mode or when running evals),
60 /// tool cards won't be deserialized
61 pub fn from_serialized_messages(
62 tools: Entity<ToolWorkingSet>,
63 messages: &[SerializedMessage],
64 project: Entity<Project>,
65 window: Option<&mut Window>, // None in headless mode
66 cx: &mut App,
67 ) -> Self {
68 let mut this = Self::new(tools);
69 let mut tool_names_by_id = HashMap::default();
70 let mut window = window;
71
72 for message in messages {
73 match message.role {
74 Role::Assistant => {
75 if !message.tool_uses.is_empty() {
76 let tool_uses = message
77 .tool_uses
78 .iter()
79 .map(|tool_use| LanguageModelToolUse {
80 id: tool_use.id.clone(),
81 name: tool_use.name.clone().into(),
82 raw_input: tool_use.input.to_string(),
83 input: tool_use.input.clone(),
84 is_input_complete: true,
85 })
86 .collect::<Vec<_>>();
87
88 tool_names_by_id.extend(
89 tool_uses
90 .iter()
91 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
92 );
93
94 this.tool_uses_by_assistant_message
95 .insert(message.id, tool_uses);
96
97 for tool_result in &message.tool_results {
98 let tool_use_id = tool_result.tool_use_id.clone();
99 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100 log::warn!("no tool name found for tool use: {tool_use_id:?}");
101 continue;
102 };
103
104 this.tool_results.insert(
105 tool_use_id.clone(),
106 LanguageModelToolResult {
107 tool_use_id: tool_use_id.clone(),
108 tool_name: tool_use.clone(),
109 is_error: tool_result.is_error,
110 content: tool_result.content.clone(),
111 output: tool_result.output.clone(),
112 },
113 );
114
115 if let Some(window) = &mut window {
116 if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
117 if let Some(output) = tool_result.output.clone() {
118 if let Some(card) = tool.deserialize_card(
119 output,
120 project.clone(),
121 window,
122 cx,
123 ) {
124 this.tool_result_cards.insert(tool_use_id, card);
125 }
126 }
127 }
128 }
129 }
130 }
131 }
132 Role::System | Role::User => {}
133 }
134 }
135
136 this
137 }
138
139 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
140 let mut canceled_tool_uses = Vec::new();
141 self.pending_tool_uses_by_id
142 .retain(|tool_use_id, tool_use| {
143 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
144 return true;
145 }
146
147 let content = "Tool canceled by user".into();
148 self.tool_results.insert(
149 tool_use_id.clone(),
150 LanguageModelToolResult {
151 tool_use_id: tool_use_id.clone(),
152 tool_name: tool_use.name.clone(),
153 content,
154 output: None,
155 is_error: true,
156 },
157 );
158 canceled_tool_uses.push(tool_use.clone());
159 false
160 });
161 canceled_tool_uses
162 }
163
164 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
165 self.pending_tool_uses_by_id.values().collect()
166 }
167
168 pub fn tool_uses_for_message(
169 &self,
170 id: MessageId,
171 project: &Entity<Project>,
172 cx: &App,
173 ) -> Vec<ToolUse> {
174 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
175 return Vec::new();
176 };
177
178 let mut tool_uses = Vec::new();
179
180 for tool_use in tool_uses_for_message.iter() {
181 let tool_result = self.tool_results.get(&tool_use.id);
182
183 let status = (|| {
184 if let Some(tool_result) = tool_result {
185 let content = tool_result
186 .content
187 .to_str()
188 .map(|str| str.to_owned().into())
189 .unwrap_or_default();
190
191 return if tool_result.is_error {
192 ToolUseStatus::Error(content)
193 } else {
194 ToolUseStatus::Finished(content)
195 };
196 }
197
198 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
199 match pending_tool_use.status {
200 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
201 PendingToolUseStatus::NeedsConfirmation { .. } => {
202 ToolUseStatus::NeedsConfirmation
203 }
204 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
205 PendingToolUseStatus::Error(ref err) => {
206 ToolUseStatus::Error(err.clone().into())
207 }
208 PendingToolUseStatus::InputStillStreaming => {
209 ToolUseStatus::InputStillStreaming
210 }
211 }
212 } else {
213 ToolUseStatus::Pending
214 }
215 })();
216
217 let (icon, needs_confirmation) =
218 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
219 (
220 tool.icon(),
221 tool.needs_confirmation(&tool_use.input, project, cx),
222 )
223 } else {
224 (IconName::Cog, false)
225 };
226
227 tool_uses.push(ToolUse {
228 id: tool_use.id.clone(),
229 name: tool_use.name.clone().into(),
230 ui_text: self.tool_ui_label(
231 &tool_use.name,
232 &tool_use.input,
233 tool_use.is_input_complete,
234 cx,
235 ),
236 input: tool_use.input.clone(),
237 status,
238 icon,
239 needs_confirmation,
240 })
241 }
242
243 tool_uses
244 }
245
246 pub fn tool_ui_label(
247 &self,
248 tool_name: &str,
249 input: &serde_json::Value,
250 is_input_complete: bool,
251 cx: &App,
252 ) -> SharedString {
253 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
254 if is_input_complete {
255 tool.ui_text(input).into()
256 } else {
257 tool.still_streaming_ui_text(input).into()
258 }
259 } else {
260 format!("Unknown tool {tool_name:?}").into()
261 }
262 }
263
264 pub fn tool_results_for_message(
265 &self,
266 assistant_message_id: MessageId,
267 ) -> Vec<&LanguageModelToolResult> {
268 let Some(tool_uses) = self
269 .tool_uses_by_assistant_message
270 .get(&assistant_message_id)
271 else {
272 return Vec::new();
273 };
274
275 tool_uses
276 .iter()
277 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
278 .collect()
279 }
280
281 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
282 self.tool_uses_by_assistant_message
283 .get(&assistant_message_id)
284 .map_or(false, |results| !results.is_empty())
285 }
286
287 pub fn tool_result(
288 &self,
289 tool_use_id: &LanguageModelToolUseId,
290 ) -> Option<&LanguageModelToolResult> {
291 self.tool_results.get(tool_use_id)
292 }
293
294 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
295 self.tool_result_cards.get(tool_use_id)
296 }
297
298 pub fn insert_tool_result_card(
299 &mut self,
300 tool_use_id: LanguageModelToolUseId,
301 card: AnyToolCard,
302 ) {
303 self.tool_result_cards.insert(tool_use_id, card);
304 }
305
306 pub fn request_tool_use(
307 &mut self,
308 assistant_message_id: MessageId,
309 tool_use: LanguageModelToolUse,
310 metadata: ToolUseMetadata,
311 cx: &App,
312 ) -> Arc<str> {
313 let tool_uses = self
314 .tool_uses_by_assistant_message
315 .entry(assistant_message_id)
316 .or_default();
317
318 let mut existing_tool_use_found = false;
319
320 for existing_tool_use in tool_uses.iter_mut() {
321 if existing_tool_use.id == tool_use.id {
322 *existing_tool_use = tool_use.clone();
323 existing_tool_use_found = true;
324 }
325 }
326
327 if !existing_tool_use_found {
328 tool_uses.push(tool_use.clone());
329 }
330
331 let status = if tool_use.is_input_complete {
332 self.tool_use_metadata_by_id
333 .insert(tool_use.id.clone(), metadata);
334
335 PendingToolUseStatus::Idle
336 } else {
337 PendingToolUseStatus::InputStillStreaming
338 };
339
340 let ui_text: Arc<str> = self
341 .tool_ui_label(
342 &tool_use.name,
343 &tool_use.input,
344 tool_use.is_input_complete,
345 cx,
346 )
347 .into();
348
349 let may_perform_edits = self
350 .tools
351 .read(cx)
352 .tool(&tool_use.name, cx)
353 .is_some_and(|tool| tool.may_perform_edits());
354
355 self.pending_tool_uses_by_id.insert(
356 tool_use.id.clone(),
357 PendingToolUse {
358 assistant_message_id,
359 id: tool_use.id,
360 name: tool_use.name.clone(),
361 ui_text: ui_text.clone(),
362 input: tool_use.input,
363 may_perform_edits,
364 status,
365 },
366 );
367
368 ui_text
369 }
370
371 pub fn run_pending_tool(
372 &mut self,
373 tool_use_id: LanguageModelToolUseId,
374 ui_text: SharedString,
375 task: Task<()>,
376 ) {
377 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
378 tool_use.ui_text = ui_text.into();
379 tool_use.status = PendingToolUseStatus::Running {
380 _task: task.shared(),
381 };
382 }
383 }
384
385 pub fn confirm_tool_use(
386 &mut self,
387 tool_use_id: LanguageModelToolUseId,
388 ui_text: impl Into<Arc<str>>,
389 input: serde_json::Value,
390 request: Arc<LanguageModelRequest>,
391 tool: Arc<dyn Tool>,
392 ) {
393 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
394 let ui_text = ui_text.into();
395 tool_use.ui_text = ui_text.clone();
396 let confirmation = Confirmation {
397 tool_use_id,
398 input,
399 request,
400 tool,
401 ui_text,
402 };
403 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
404 }
405 }
406
407 pub fn insert_tool_output(
408 &mut self,
409 tool_use_id: LanguageModelToolUseId,
410 tool_name: Arc<str>,
411 output: Result<ToolResultOutput>,
412 configured_model: Option<&ConfiguredModel>,
413 completion_mode: CompletionMode,
414 ) -> Option<PendingToolUse> {
415 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
416
417 telemetry::event!(
418 "Agent Tool Finished",
419 model = metadata
420 .as_ref()
421 .map(|metadata| metadata.model.telemetry_id()),
422 model_provider = metadata
423 .as_ref()
424 .map(|metadata| metadata.model.provider_id().to_string()),
425 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
426 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
427 tool_name,
428 success = output.is_ok()
429 );
430
431 match output {
432 Ok(output) => {
433 let tool_result = output.content;
434 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
435
436 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
437
438 // Protect from overly large output
439 let tool_output_limit = configured_model
440 .map(|model| {
441 model.model.max_token_count_for_mode(completion_mode.into()) as usize
442 * BYTES_PER_TOKEN_ESTIMATE
443 })
444 .unwrap_or(usize::MAX);
445
446 let content = match tool_result {
447 ToolResultContent::Text(text) => {
448 let text = if text.len() < tool_output_limit {
449 text
450 } else {
451 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
452 format!(
453 "Tool result too long. The first {} bytes:\n\n{}",
454 truncated.len(),
455 truncated
456 )
457 };
458 LanguageModelToolResultContent::Text(text.into())
459 }
460 ToolResultContent::Image(language_model_image) => {
461 if language_model_image.estimate_tokens() < tool_output_limit {
462 LanguageModelToolResultContent::Image(language_model_image)
463 } else {
464 self.tool_results.insert(
465 tool_use_id.clone(),
466 LanguageModelToolResult {
467 tool_use_id: tool_use_id.clone(),
468 tool_name,
469 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
470 is_error: true,
471 output: None,
472 },
473 );
474
475 return old_use;
476 }
477 }
478 };
479
480 self.tool_results.insert(
481 tool_use_id.clone(),
482 LanguageModelToolResult {
483 tool_use_id: tool_use_id.clone(),
484 tool_name,
485 content,
486 is_error: false,
487 output: output.output,
488 },
489 );
490
491 old_use
492 }
493 Err(err) => {
494 self.tool_results.insert(
495 tool_use_id.clone(),
496 LanguageModelToolResult {
497 tool_use_id: tool_use_id.clone(),
498 tool_name,
499 content: LanguageModelToolResultContent::Text(err.to_string().into()),
500 is_error: true,
501 output: None,
502 },
503 );
504
505 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
506 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
507 }
508
509 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
510 }
511 }
512 }
513
514 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
515 self.tool_uses_by_assistant_message
516 .contains_key(&assistant_message_id)
517 }
518
519 pub fn tool_results(
520 &self,
521 assistant_message_id: MessageId,
522 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
523 self.tool_uses_by_assistant_message
524 .get(&assistant_message_id)
525 .into_iter()
526 .flatten()
527 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
528 }
529}
530
/// A tool use that has been requested and is tracked in the pending set of
/// [`ToolUseState`].
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label describing the tool use; updated when the tool starts running or
    /// asks for confirmation.
    pub ui_text: Arc<str>,
    /// JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Where the tool use currently is in its lifecycle.
    pub status: PendingToolUseStatus,
    /// Whether the tool reports that it may perform edits; `false` when the
    /// tool was not found in the working set at request time.
    pub may_perform_edits: bool,
}
543
/// Data stored with a pending tool use while it is in the
/// `NeedsConfirmation` state (see [`ToolUseState::confirm_tool_use`]).
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// ID of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input passed along with the confirmation request.
    pub input: serde_json::Value,
    /// Label describing the tool use.
    pub ui_text: Arc<str>,
    /// The language model request passed along with the confirmation request.
    pub request: Arc<LanguageModelRequest>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
552
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The tool use was requested while its input was still incomplete.
    InputStillStreaming,
    /// The tool use was requested with complete input and nothing further has
    /// happened to it yet.
    Idle,
    /// The tool use is waiting for confirmation; carries the data recorded by
    /// [`ToolUseState::confirm_tool_use`].
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running; holds the shared task passed to
    /// [`ToolUseState::run_pending_tool`].
    Running { _task: Shared<Task<()>> },
    /// The tool produced an error; holds the error message.
    Error(#[allow(unused)] Arc<str>),
}
561
562impl PendingToolUseStatus {
563 pub fn is_idle(&self) -> bool {
564 matches!(self, PendingToolUseStatus::Idle)
565 }
566
567 pub fn is_error(&self) -> bool {
568 matches!(self, PendingToolUseStatus::Error(_))
569 }
570
571 pub fn needs_confirmation(&self) -> bool {
572 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
573 }
574}
575
/// Metadata kept for a tool use from the moment its input is complete until
/// its output is inserted, when it is reported in the "Agent Tool Finished"
/// telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry ID and provider ID are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread ID reported with the event.
    pub thread_id: ThreadId,
    /// Prompt ID reported with the event.
    pub prompt_id: PromptId,
}