1use crate::{
2 thread::{MessageId, PromptId, ThreadId},
3 thread_store::SerializedMessage,
4};
5use agent_settings::CompletionMode;
6use anyhow::Result;
7use assistant_tool::{
8 AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
9};
10use collections::HashMap;
11use futures::{FutureExt as _, future::Shared};
12use gpui::{App, Entity, SharedString, Task, Window};
13use icons::IconName;
14use language_model::{
15 ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
16 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
17 LanguageModelToolUseId, Role,
18};
19use project::Project;
20use std::sync::Arc;
21use util::truncate_lines_to_byte_limit;
22
/// A snapshot of a single tool call, as assembled by
/// [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool call, copied from the recorded [`LanguageModelToolUse`].
    pub id: LanguageModelToolUseId,
    /// Name of the requested tool.
    pub name: SharedString,
    /// Label produced by [`ToolUseState::tool_ui_label`] for this call.
    pub ui_text: SharedString,
    /// Status derived from the stored result or, failing that, the pending entry.
    pub status: ToolUseStatus,
    /// Input the tool was requested with.
    pub input: serde_json::Value,
    /// Icon reported by the tool; `IconName::Cog` when the tool is not in the working set.
    pub icon: icons::IconName,
    /// Whether the tool reports that this input needs confirmation; `false` when
    /// the tool is not in the working set.
    pub needs_confirmation: bool,
}
33
/// Bookkeeping for the tool calls requested by assistant messages: which tool
/// uses were requested, which are still pending, and which results and cards
/// have been recorded for them.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Recorded results, keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have not been removed by a successful
    /// output or a cancellation. Entries whose tool failed stay here with an
    /// error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards associated with tool uses, keyed by tool use id.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete; removed again when
    /// the tool's output is inserted, where it feeds the telemetry event.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
42
43impl ToolUseState {
44 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
45 Self {
46 tools,
47 tool_uses_by_assistant_message: HashMap::default(),
48 tool_results: HashMap::default(),
49 pending_tool_uses_by_id: HashMap::default(),
50 tool_result_cards: HashMap::default(),
51 tool_use_metadata_by_id: HashMap::default(),
52 }
53 }
54
55 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
56 ///
57 /// Accepts a function to filter the tools that should be used to populate the state.
58 ///
59 /// If `window` is `None` (e.g., when in headless mode or when running evals),
60 /// tool cards won't be deserialized
61 pub fn from_serialized_messages(
62 tools: Entity<ToolWorkingSet>,
63 messages: &[SerializedMessage],
64 project: Entity<Project>,
65 window: Option<&mut Window>, // None in headless mode
66 cx: &mut App,
67 ) -> Self {
68 let mut this = Self::new(tools);
69 let mut tool_names_by_id = HashMap::default();
70 let mut window = window;
71
72 for message in messages {
73 match message.role {
74 Role::Assistant => {
75 if !message.tool_uses.is_empty() {
76 let tool_uses = message
77 .tool_uses
78 .iter()
79 .map(|tool_use| LanguageModelToolUse {
80 id: tool_use.id.clone(),
81 name: tool_use.name.clone().into(),
82 raw_input: tool_use.input.to_string(),
83 input: tool_use.input.clone(),
84 is_input_complete: true,
85 })
86 .collect::<Vec<_>>();
87
88 tool_names_by_id.extend(
89 tool_uses
90 .iter()
91 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
92 );
93
94 this.tool_uses_by_assistant_message
95 .insert(message.id, tool_uses);
96
97 for tool_result in &message.tool_results {
98 let tool_use_id = tool_result.tool_use_id.clone();
99 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100 log::warn!("no tool name found for tool use: {tool_use_id:?}");
101 continue;
102 };
103
104 this.tool_results.insert(
105 tool_use_id.clone(),
106 LanguageModelToolResult {
107 tool_use_id: tool_use_id.clone(),
108 tool_name: tool_use.clone(),
109 is_error: tool_result.is_error,
110 content: tool_result.content.clone(),
111 output: tool_result.output.clone(),
112 },
113 );
114
115 if let Some(window) = &mut window
116 && let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
117 && let Some(output) = tool_result.output.clone()
118 && let Some(card) =
119 tool.deserialize_card(output, project.clone(), window, cx)
120 {
121 this.tool_result_cards.insert(tool_use_id, card);
122 }
123 }
124 }
125 }
126 Role::System | Role::User => {}
127 }
128 }
129
130 this
131 }
132
133 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
134 let mut canceled_tool_uses = Vec::new();
135 self.pending_tool_uses_by_id
136 .retain(|tool_use_id, tool_use| {
137 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
138 return true;
139 }
140
141 let content = "Tool canceled by user".into();
142 self.tool_results.insert(
143 tool_use_id.clone(),
144 LanguageModelToolResult {
145 tool_use_id: tool_use_id.clone(),
146 tool_name: tool_use.name.clone(),
147 content,
148 output: None,
149 is_error: true,
150 },
151 );
152 canceled_tool_uses.push(tool_use.clone());
153 false
154 });
155 canceled_tool_uses
156 }
157
158 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
159 self.pending_tool_uses_by_id.values().collect()
160 }
161
162 pub fn tool_uses_for_message(
163 &self,
164 id: MessageId,
165 project: &Entity<Project>,
166 cx: &App,
167 ) -> Vec<ToolUse> {
168 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169 return Vec::new();
170 };
171
172 let mut tool_uses = Vec::new();
173
174 for tool_use in tool_uses_for_message.iter() {
175 let tool_result = self.tool_results.get(&tool_use.id);
176
177 let status = (|| {
178 if let Some(tool_result) = tool_result {
179 let content = tool_result
180 .content
181 .to_str()
182 .map(|str| str.to_owned().into())
183 .unwrap_or_default();
184
185 return if tool_result.is_error {
186 ToolUseStatus::Error(content)
187 } else {
188 ToolUseStatus::Finished(content)
189 };
190 }
191
192 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
193 match pending_tool_use.status {
194 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
195 PendingToolUseStatus::NeedsConfirmation { .. } => {
196 ToolUseStatus::NeedsConfirmation
197 }
198 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
199 PendingToolUseStatus::Error(ref err) => {
200 ToolUseStatus::Error(err.clone().into())
201 }
202 PendingToolUseStatus::InputStillStreaming => {
203 ToolUseStatus::InputStillStreaming
204 }
205 }
206 } else {
207 ToolUseStatus::Pending
208 }
209 })();
210
211 let (icon, needs_confirmation) =
212 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
213 (
214 tool.icon(),
215 tool.needs_confirmation(&tool_use.input, project, cx),
216 )
217 } else {
218 (IconName::Cog, false)
219 };
220
221 tool_uses.push(ToolUse {
222 id: tool_use.id.clone(),
223 name: tool_use.name.clone().into(),
224 ui_text: self.tool_ui_label(
225 &tool_use.name,
226 &tool_use.input,
227 tool_use.is_input_complete,
228 cx,
229 ),
230 input: tool_use.input.clone(),
231 status,
232 icon,
233 needs_confirmation,
234 })
235 }
236
237 tool_uses
238 }
239
240 pub fn tool_ui_label(
241 &self,
242 tool_name: &str,
243 input: &serde_json::Value,
244 is_input_complete: bool,
245 cx: &App,
246 ) -> SharedString {
247 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
248 if is_input_complete {
249 tool.ui_text(input).into()
250 } else {
251 tool.still_streaming_ui_text(input).into()
252 }
253 } else {
254 format!("Unknown tool {tool_name:?}").into()
255 }
256 }
257
258 pub fn tool_results_for_message(
259 &self,
260 assistant_message_id: MessageId,
261 ) -> Vec<&LanguageModelToolResult> {
262 let Some(tool_uses) = self
263 .tool_uses_by_assistant_message
264 .get(&assistant_message_id)
265 else {
266 return Vec::new();
267 };
268
269 tool_uses
270 .iter()
271 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
272 .collect()
273 }
274
275 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
276 self.tool_uses_by_assistant_message
277 .get(&assistant_message_id)
278 .is_some_and(|results| !results.is_empty())
279 }
280
281 pub fn tool_result(
282 &self,
283 tool_use_id: &LanguageModelToolUseId,
284 ) -> Option<&LanguageModelToolResult> {
285 self.tool_results.get(tool_use_id)
286 }
287
288 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
289 self.tool_result_cards.get(tool_use_id)
290 }
291
292 pub fn insert_tool_result_card(
293 &mut self,
294 tool_use_id: LanguageModelToolUseId,
295 card: AnyToolCard,
296 ) {
297 self.tool_result_cards.insert(tool_use_id, card);
298 }
299
300 pub fn request_tool_use(
301 &mut self,
302 assistant_message_id: MessageId,
303 tool_use: LanguageModelToolUse,
304 metadata: ToolUseMetadata,
305 cx: &App,
306 ) -> Arc<str> {
307 let tool_uses = self
308 .tool_uses_by_assistant_message
309 .entry(assistant_message_id)
310 .or_default();
311
312 let mut existing_tool_use_found = false;
313
314 for existing_tool_use in tool_uses.iter_mut() {
315 if existing_tool_use.id == tool_use.id {
316 *existing_tool_use = tool_use.clone();
317 existing_tool_use_found = true;
318 }
319 }
320
321 if !existing_tool_use_found {
322 tool_uses.push(tool_use.clone());
323 }
324
325 let status = if tool_use.is_input_complete {
326 self.tool_use_metadata_by_id
327 .insert(tool_use.id.clone(), metadata);
328
329 PendingToolUseStatus::Idle
330 } else {
331 PendingToolUseStatus::InputStillStreaming
332 };
333
334 let ui_text: Arc<str> = self
335 .tool_ui_label(
336 &tool_use.name,
337 &tool_use.input,
338 tool_use.is_input_complete,
339 cx,
340 )
341 .into();
342
343 let may_perform_edits = self
344 .tools
345 .read(cx)
346 .tool(&tool_use.name, cx)
347 .is_some_and(|tool| tool.may_perform_edits());
348
349 self.pending_tool_uses_by_id.insert(
350 tool_use.id.clone(),
351 PendingToolUse {
352 assistant_message_id,
353 id: tool_use.id,
354 name: tool_use.name.clone(),
355 ui_text: ui_text.clone(),
356 input: tool_use.input,
357 may_perform_edits,
358 status,
359 },
360 );
361
362 ui_text
363 }
364
365 pub fn run_pending_tool(
366 &mut self,
367 tool_use_id: LanguageModelToolUseId,
368 ui_text: SharedString,
369 task: Task<()>,
370 ) {
371 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
372 tool_use.ui_text = ui_text.into();
373 tool_use.status = PendingToolUseStatus::Running {
374 _task: task.shared(),
375 };
376 }
377 }
378
379 pub fn confirm_tool_use(
380 &mut self,
381 tool_use_id: LanguageModelToolUseId,
382 ui_text: impl Into<Arc<str>>,
383 input: serde_json::Value,
384 request: Arc<LanguageModelRequest>,
385 tool: Arc<dyn Tool>,
386 ) {
387 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
388 let ui_text = ui_text.into();
389 tool_use.ui_text = ui_text.clone();
390 let confirmation = Confirmation {
391 tool_use_id,
392 input,
393 request,
394 tool,
395 ui_text,
396 };
397 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
398 }
399 }
400
401 pub fn insert_tool_output(
402 &mut self,
403 tool_use_id: LanguageModelToolUseId,
404 tool_name: Arc<str>,
405 output: Result<ToolResultOutput>,
406 configured_model: Option<&ConfiguredModel>,
407 completion_mode: CompletionMode,
408 ) -> Option<PendingToolUse> {
409 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
410
411 telemetry::event!(
412 "Agent Tool Finished",
413 model = metadata
414 .as_ref()
415 .map(|metadata| metadata.model.telemetry_id()),
416 model_provider = metadata
417 .as_ref()
418 .map(|metadata| metadata.model.provider_id().to_string()),
419 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
420 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
421 tool_name,
422 success = output.is_ok()
423 );
424
425 match output {
426 Ok(output) => {
427 let tool_result = output.content;
428 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
429
430 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
431
432 // Protect from overly large output
433 let tool_output_limit = configured_model
434 .map(|model| {
435 model.model.max_token_count_for_mode(completion_mode.into()) as usize
436 * BYTES_PER_TOKEN_ESTIMATE
437 })
438 .unwrap_or(usize::MAX);
439
440 let content = match tool_result {
441 ToolResultContent::Text(text) => {
442 let text = if text.len() < tool_output_limit {
443 text
444 } else {
445 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
446 format!(
447 "Tool result too long. The first {} bytes:\n\n{}",
448 truncated.len(),
449 truncated
450 )
451 };
452 LanguageModelToolResultContent::Text(text.into())
453 }
454 ToolResultContent::Image(language_model_image) => {
455 if language_model_image.estimate_tokens() < tool_output_limit {
456 LanguageModelToolResultContent::Image(language_model_image)
457 } else {
458 self.tool_results.insert(
459 tool_use_id.clone(),
460 LanguageModelToolResult {
461 tool_use_id: tool_use_id.clone(),
462 tool_name,
463 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
464 is_error: true,
465 output: None,
466 },
467 );
468
469 return old_use;
470 }
471 }
472 };
473
474 self.tool_results.insert(
475 tool_use_id.clone(),
476 LanguageModelToolResult {
477 tool_use_id: tool_use_id.clone(),
478 tool_name,
479 content,
480 is_error: false,
481 output: output.output,
482 },
483 );
484
485 old_use
486 }
487 Err(err) => {
488 self.tool_results.insert(
489 tool_use_id.clone(),
490 LanguageModelToolResult {
491 tool_use_id: tool_use_id.clone(),
492 tool_name,
493 content: LanguageModelToolResultContent::Text(err.to_string().into()),
494 is_error: true,
495 output: None,
496 },
497 );
498
499 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
500 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
501 }
502
503 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
504 }
505 }
506 }
507
508 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
509 self.tool_uses_by_assistant_message
510 .contains_key(&assistant_message_id)
511 }
512
513 pub fn tool_results(
514 &self,
515 assistant_message_id: MessageId,
516 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
517 self.tool_uses_by_assistant_message
518 .get(&assistant_message_id)
519 .into_iter()
520 .flatten()
521 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
522 }
523}
524
/// A tool use that has been requested and is tracked in
/// `ToolUseState::pending_tool_uses_by_id`.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; updated when the tool use is run or needs confirmation.
    pub ui_text: Arc<str>,
    /// Input the tool was requested with.
    pub input: serde_json::Value,
    /// Where the tool use currently stands; see [`PendingToolUseStatus`].
    pub status: PendingToolUseStatus,
    /// Value of the tool's `may_perform_edits()` at request time; `false` when
    /// the tool was not found in the working set.
    pub may_perform_edits: bool,
}
537
/// The data bundled by `ToolUseState::confirm_tool_use` and carried by
/// [`PendingToolUseStatus::NeedsConfirmation`].
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Input passed to `confirm_tool_use`.
    pub input: serde_json::Value,
    /// Label passed to `confirm_tool_use`; also set on the pending tool use.
    pub ui_text: Arc<str>,
    /// Request passed to `confirm_tool_use`.
    pub request: Arc<LanguageModelRequest>,
    /// Tool passed to `confirm_tool_use`.
    pub tool: Arc<dyn Tool>,
}
546
/// Status of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Assigned when the tool use is requested while its input is incomplete.
    InputStillStreaming,
    /// Assigned when the tool use is requested with complete input.
    Idle,
    /// Set by `ToolUseState::confirm_tool_use`; carries the data gathered for the confirmation.
    NeedsConfirmation(Arc<Confirmation>),
    /// Set by `ToolUseState::run_pending_tool`; holds the shared form of the task passed to it.
    Running { _task: Shared<Task<()>> },
    /// Set by `ToolUseState::insert_tool_output` when the tool output is an error;
    /// holds the error text.
    Error(#[allow(unused)] Arc<str>),
}
555
556impl PendingToolUseStatus {
557 pub fn is_idle(&self) -> bool {
558 matches!(self, PendingToolUseStatus::Idle)
559 }
560
561 pub fn is_error(&self) -> bool {
562 matches!(self, PendingToolUseStatus::Error(_))
563 }
564
565 pub fn needs_confirmation(&self) -> bool {
566 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
567 }
568}
569
/// Metadata stored alongside a tool use once its input is complete. It is read
/// back in `ToolUseState::insert_tool_output` to fill in the
/// "Agent Tool Finished" telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry id and provider id are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread id reported with the event.
    pub thread_id: ThreadId,
    /// Prompt id reported with the event.
    pub prompt_id: PromptId,
}