1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{
5 AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
6};
7use collections::HashMap;
8use futures::FutureExt as _;
9use futures::future::Shared;
10use gpui::{App, Entity, SharedString, Task};
11use language_model::{
12 ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
13 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
14};
15use project::Project;
16use ui::{IconName, Window};
17use util::truncate_lines_to_byte_limit;
18
19use crate::thread::{MessageId, PromptId, ThreadId};
20use crate::thread_store::SerializedMessage;
21
22#[derive(Debug)]
23pub struct ToolUse {
24 pub id: LanguageModelToolUseId,
25 pub name: SharedString,
26 pub ui_text: SharedString,
27 pub status: ToolUseStatus,
28 pub input: serde_json::Value,
29 pub icon: ui::IconName,
30 pub needs_confirmation: bool,
31}
32
33pub struct ToolUseState {
34 tools: Entity<ToolWorkingSet>,
35 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
36 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
37 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
38 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
39 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
40}
41
42impl ToolUseState {
43 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
44 Self {
45 tools,
46 tool_uses_by_assistant_message: HashMap::default(),
47 tool_results: HashMap::default(),
48 pending_tool_uses_by_id: HashMap::default(),
49 tool_result_cards: HashMap::default(),
50 tool_use_metadata_by_id: HashMap::default(),
51 }
52 }
53
54 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
55 ///
56 /// Accepts a function to filter the tools that should be used to populate the state.
57 ///
58 /// If `window` is `None` (e.g., when in headless mode or when running evals),
59 /// tool cards won't be deserialized
60 pub fn from_serialized_messages(
61 tools: Entity<ToolWorkingSet>,
62 messages: &[SerializedMessage],
63 project: Entity<Project>,
64 window: Option<&mut Window>, // None in headless mode
65 cx: &mut App,
66 ) -> Self {
67 let mut this = Self::new(tools);
68 let mut tool_names_by_id = HashMap::default();
69 let mut window = window;
70
71 for message in messages {
72 match message.role {
73 Role::Assistant => {
74 if !message.tool_uses.is_empty() {
75 let tool_uses = message
76 .tool_uses
77 .iter()
78 .map(|tool_use| LanguageModelToolUse {
79 id: tool_use.id.clone(),
80 name: tool_use.name.clone().into(),
81 raw_input: tool_use.input.to_string(),
82 input: tool_use.input.clone(),
83 is_input_complete: true,
84 })
85 .collect::<Vec<_>>();
86
87 tool_names_by_id.extend(
88 tool_uses
89 .iter()
90 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
91 );
92
93 this.tool_uses_by_assistant_message
94 .insert(message.id, tool_uses);
95
96 for tool_result in &message.tool_results {
97 let tool_use_id = tool_result.tool_use_id.clone();
98 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
99 log::warn!("no tool name found for tool use: {tool_use_id:?}");
100 continue;
101 };
102
103 this.tool_results.insert(
104 tool_use_id.clone(),
105 LanguageModelToolResult {
106 tool_use_id: tool_use_id.clone(),
107 tool_name: tool_use.clone(),
108 is_error: tool_result.is_error,
109 content: tool_result.content.clone(),
110 output: tool_result.output.clone(),
111 },
112 );
113
114 if let Some(window) = &mut window {
115 if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
116 if let Some(output) = tool_result.output.clone() {
117 if let Some(card) = tool.deserialize_card(
118 output,
119 project.clone(),
120 window,
121 cx,
122 ) {
123 this.tool_result_cards.insert(tool_use_id, card);
124 }
125 }
126 }
127 }
128 }
129 }
130 }
131 Role::System | Role::User => {}
132 }
133 }
134
135 this
136 }
137
138 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
139 let mut cancelled_tool_uses = Vec::new();
140 self.pending_tool_uses_by_id
141 .retain(|tool_use_id, tool_use| {
142 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
143 return true;
144 }
145
146 let content = "Tool canceled by user".into();
147 self.tool_results.insert(
148 tool_use_id.clone(),
149 LanguageModelToolResult {
150 tool_use_id: tool_use_id.clone(),
151 tool_name: tool_use.name.clone(),
152 content,
153 output: None,
154 is_error: true,
155 },
156 );
157 cancelled_tool_uses.push(tool_use.clone());
158 false
159 });
160 cancelled_tool_uses
161 }
162
163 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
164 self.pending_tool_uses_by_id.values().collect()
165 }
166
167 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
168 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169 return Vec::new();
170 };
171
172 let mut tool_uses = Vec::new();
173
174 for tool_use in tool_uses_for_message.iter() {
175 let tool_result = self.tool_results.get(&tool_use.id);
176
177 let status = (|| {
178 if let Some(tool_result) = tool_result {
179 let content = tool_result
180 .content
181 .to_str()
182 .map(|str| str.to_owned().into())
183 .unwrap_or_default();
184
185 return if tool_result.is_error {
186 ToolUseStatus::Error(content)
187 } else {
188 ToolUseStatus::Finished(content)
189 };
190 }
191
192 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
193 match pending_tool_use.status {
194 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
195 PendingToolUseStatus::NeedsConfirmation { .. } => {
196 ToolUseStatus::NeedsConfirmation
197 }
198 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
199 PendingToolUseStatus::Error(ref err) => {
200 ToolUseStatus::Error(err.clone().into())
201 }
202 PendingToolUseStatus::InputStillStreaming => {
203 ToolUseStatus::InputStillStreaming
204 }
205 }
206 } else {
207 ToolUseStatus::Pending
208 }
209 })();
210
211 let (icon, needs_confirmation) =
212 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
213 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
214 } else {
215 (IconName::Cog, false)
216 };
217
218 tool_uses.push(ToolUse {
219 id: tool_use.id.clone(),
220 name: tool_use.name.clone().into(),
221 ui_text: self.tool_ui_label(
222 &tool_use.name,
223 &tool_use.input,
224 tool_use.is_input_complete,
225 cx,
226 ),
227 input: tool_use.input.clone(),
228 status,
229 icon,
230 needs_confirmation,
231 })
232 }
233
234 tool_uses
235 }
236
237 pub fn tool_ui_label(
238 &self,
239 tool_name: &str,
240 input: &serde_json::Value,
241 is_input_complete: bool,
242 cx: &App,
243 ) -> SharedString {
244 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
245 if is_input_complete {
246 tool.ui_text(input).into()
247 } else {
248 tool.still_streaming_ui_text(input).into()
249 }
250 } else {
251 format!("Unknown tool {tool_name:?}").into()
252 }
253 }
254
255 pub fn tool_results_for_message(
256 &self,
257 assistant_message_id: MessageId,
258 ) -> Vec<&LanguageModelToolResult> {
259 let Some(tool_uses) = self
260 .tool_uses_by_assistant_message
261 .get(&assistant_message_id)
262 else {
263 return Vec::new();
264 };
265
266 tool_uses
267 .iter()
268 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
269 .collect()
270 }
271
272 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
273 self.tool_uses_by_assistant_message
274 .get(&assistant_message_id)
275 .map_or(false, |results| !results.is_empty())
276 }
277
278 pub fn tool_result(
279 &self,
280 tool_use_id: &LanguageModelToolUseId,
281 ) -> Option<&LanguageModelToolResult> {
282 self.tool_results.get(tool_use_id)
283 }
284
285 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
286 self.tool_result_cards.get(tool_use_id)
287 }
288
289 pub fn insert_tool_result_card(
290 &mut self,
291 tool_use_id: LanguageModelToolUseId,
292 card: AnyToolCard,
293 ) {
294 self.tool_result_cards.insert(tool_use_id, card);
295 }
296
297 pub fn request_tool_use(
298 &mut self,
299 assistant_message_id: MessageId,
300 tool_use: LanguageModelToolUse,
301 metadata: ToolUseMetadata,
302 cx: &App,
303 ) -> Arc<str> {
304 let tool_uses = self
305 .tool_uses_by_assistant_message
306 .entry(assistant_message_id)
307 .or_default();
308
309 let mut existing_tool_use_found = false;
310
311 for existing_tool_use in tool_uses.iter_mut() {
312 if existing_tool_use.id == tool_use.id {
313 *existing_tool_use = tool_use.clone();
314 existing_tool_use_found = true;
315 }
316 }
317
318 if !existing_tool_use_found {
319 tool_uses.push(tool_use.clone());
320 }
321
322 let status = if tool_use.is_input_complete {
323 self.tool_use_metadata_by_id
324 .insert(tool_use.id.clone(), metadata);
325
326 PendingToolUseStatus::Idle
327 } else {
328 PendingToolUseStatus::InputStillStreaming
329 };
330
331 let ui_text: Arc<str> = self
332 .tool_ui_label(
333 &tool_use.name,
334 &tool_use.input,
335 tool_use.is_input_complete,
336 cx,
337 )
338 .into();
339
340 let may_perform_edits = self
341 .tools
342 .read(cx)
343 .tool(&tool_use.name, cx)
344 .is_some_and(|tool| tool.may_perform_edits());
345
346 self.pending_tool_uses_by_id.insert(
347 tool_use.id.clone(),
348 PendingToolUse {
349 assistant_message_id,
350 id: tool_use.id,
351 name: tool_use.name.clone(),
352 ui_text: ui_text.clone(),
353 input: tool_use.input,
354 may_perform_edits,
355 status,
356 },
357 );
358
359 ui_text
360 }
361
362 pub fn run_pending_tool(
363 &mut self,
364 tool_use_id: LanguageModelToolUseId,
365 ui_text: SharedString,
366 task: Task<()>,
367 ) {
368 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
369 tool_use.ui_text = ui_text.into();
370 tool_use.status = PendingToolUseStatus::Running {
371 _task: task.shared(),
372 };
373 }
374 }
375
376 pub fn confirm_tool_use(
377 &mut self,
378 tool_use_id: LanguageModelToolUseId,
379 ui_text: impl Into<Arc<str>>,
380 input: serde_json::Value,
381 request: Arc<LanguageModelRequest>,
382 tool: Arc<dyn Tool>,
383 ) {
384 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
385 let ui_text = ui_text.into();
386 tool_use.ui_text = ui_text.clone();
387 let confirmation = Confirmation {
388 tool_use_id,
389 input,
390 request,
391 tool,
392 ui_text,
393 };
394 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
395 }
396 }
397
398 pub fn insert_tool_output(
399 &mut self,
400 tool_use_id: LanguageModelToolUseId,
401 tool_name: Arc<str>,
402 output: Result<ToolResultOutput>,
403 configured_model: Option<&ConfiguredModel>,
404 ) -> Option<PendingToolUse> {
405 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
406
407 telemetry::event!(
408 "Agent Tool Finished",
409 model = metadata
410 .as_ref()
411 .map(|metadata| metadata.model.telemetry_id()),
412 model_provider = metadata
413 .as_ref()
414 .map(|metadata| metadata.model.provider_id().to_string()),
415 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
416 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
417 tool_name,
418 success = output.is_ok()
419 );
420
421 match output {
422 Ok(output) => {
423 let tool_result = output.content;
424 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
425
426 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
427
428 // Protect from overly large output
429 let tool_output_limit = configured_model
430 .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
431 .unwrap_or(usize::MAX);
432
433 let content = match tool_result {
434 ToolResultContent::Text(text) => {
435 let text = if text.len() < tool_output_limit {
436 text
437 } else {
438 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
439 format!(
440 "Tool result too long. The first {} bytes:\n\n{}",
441 truncated.len(),
442 truncated
443 )
444 };
445 LanguageModelToolResultContent::Text(text.into())
446 }
447 ToolResultContent::Image(language_model_image) => {
448 if language_model_image.estimate_tokens() < tool_output_limit {
449 LanguageModelToolResultContent::Image(language_model_image)
450 } else {
451 self.tool_results.insert(
452 tool_use_id.clone(),
453 LanguageModelToolResult {
454 tool_use_id: tool_use_id.clone(),
455 tool_name,
456 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
457 is_error: true,
458 output: None,
459 },
460 );
461
462 return old_use;
463 }
464 }
465 };
466
467 self.tool_results.insert(
468 tool_use_id.clone(),
469 LanguageModelToolResult {
470 tool_use_id: tool_use_id.clone(),
471 tool_name,
472 content,
473 is_error: false,
474 output: output.output,
475 },
476 );
477
478 old_use
479 }
480 Err(err) => {
481 self.tool_results.insert(
482 tool_use_id.clone(),
483 LanguageModelToolResult {
484 tool_use_id: tool_use_id.clone(),
485 tool_name,
486 content: LanguageModelToolResultContent::Text(err.to_string().into()),
487 is_error: true,
488 output: None,
489 },
490 );
491
492 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
493 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
494 }
495
496 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
497 }
498 }
499 }
500
501 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
502 self.tool_uses_by_assistant_message
503 .contains_key(&assistant_message_id)
504 }
505
506 pub fn tool_results(
507 &self,
508 assistant_message_id: MessageId,
509 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
510 self.tool_uses_by_assistant_message
511 .get(&assistant_message_id)
512 .into_iter()
513 .flatten()
514 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
515 }
516}
517
518#[derive(Debug, Clone)]
519pub struct PendingToolUse {
520 pub id: LanguageModelToolUseId,
521 /// The ID of the Assistant message in which the tool use was requested.
522 #[allow(unused)]
523 pub assistant_message_id: MessageId,
524 pub name: Arc<str>,
525 pub ui_text: Arc<str>,
526 pub input: serde_json::Value,
527 pub status: PendingToolUseStatus,
528 pub may_perform_edits: bool,
529}
530
531#[derive(Debug, Clone)]
532pub struct Confirmation {
533 pub tool_use_id: LanguageModelToolUseId,
534 pub input: serde_json::Value,
535 pub ui_text: Arc<str>,
536 pub request: Arc<LanguageModelRequest>,
537 pub tool: Arc<dyn Tool>,
538}
539
540#[derive(Debug, Clone)]
541pub enum PendingToolUseStatus {
542 InputStillStreaming,
543 Idle,
544 NeedsConfirmation(Arc<Confirmation>),
545 Running { _task: Shared<Task<()>> },
546 Error(#[allow(unused)] Arc<str>),
547}
548
549impl PendingToolUseStatus {
550 pub fn is_idle(&self) -> bool {
551 matches!(self, PendingToolUseStatus::Idle)
552 }
553
554 pub fn is_error(&self) -> bool {
555 matches!(self, PendingToolUseStatus::Error(_))
556 }
557
558 pub fn needs_confirmation(&self) -> bool {
559 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
560 }
561}
562
563#[derive(Clone)]
564pub struct ToolUseMetadata {
565 pub model: Arc<dyn LanguageModel>,
566 pub thread_id: ThreadId,
567 pub prompt_id: PromptId,
568}