use crate::{
    thread::{MessageId, PromptId, ThreadId},
    thread_store::SerializedMessage,
};
use anyhow::Result;
use assistant_tool::{
    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
use gpui::{App, Entity, SharedString, Task, Window};
use icons::IconName;
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
};
use project::Project;
use std::sync::Arc;
use util::truncate_lines_to_byte_limit;

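/// A single tool use as surfaced to the UI, pairing the model's request with
/// its current status, display label, and icon.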
#[derive(Debug)]
pub struct ToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub ui_text: SharedString,
    pub status: ToolUseStatus,
    pub input: serde_json::Value,
    pub icon: IconName,
    pub needs_confirmation: bool,
}

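/// Tracks the tool uses requested by each assistant message in a thread, along
/// with their pending executions, results, result cards, and telemetry metadata.
///
/// Illustrative lifecycle sketch (not a doctest; `state`, `message_id`,
/// `tool_use`, `metadata`, `task`, `output`, `configured_model`, and `cx` are
/// assumed to be provided by the surrounding thread):
///
/// ```ignore
/// // 1. Register a tool use streamed from the model and get its UI label.
/// let ui_text = state.request_tool_use(message_id, tool_use.clone(), metadata, cx);
/// // 2. Mark it as running once the thread has spawned the tool's task.
/// state.run_pending_tool(tool_use.id.clone(), ui_text.to_string().into(), task);
/// // 3. Record the tool's output, truncating it to the model's token budget.
/// state.insert_tool_output(tool_use.id, tool_use.name, output, configured_model);
/// ```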
pub struct ToolUseState {
    tools: Entity<ToolWorkingSet>,
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}

impl ToolUseState {
    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
        Self {
            tools,
            tool_uses_by_assistant_message: HashMap::default(),
            tool_results: HashMap::default(),
            pending_tool_uses_by_id: HashMap::default(),
            tool_result_cards: HashMap::default(),
            tool_use_metadata_by_id: HashMap::default(),
        }
    }

    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
    ///
    /// If `window` is `None` (e.g., in headless mode or when running evals),
    /// tool cards won't be deserialized.
    pub fn from_serialized_messages(
        tools: Entity<ToolWorkingSet>,
        messages: &[SerializedMessage],
        project: Entity<Project>,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut App,
    ) -> Self {
        let mut this = Self::new(tools);
        let mut tool_names_by_id = HashMap::default();
        let mut window = window;

        for message in messages {
            match message.role {
                Role::Assistant => {
                    if !message.tool_uses.is_empty() {
                        let tool_uses = message
                            .tool_uses
                            .iter()
                            .map(|tool_use| LanguageModelToolUse {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone().into(),
                                raw_input: tool_use.input.to_string(),
                                input: tool_use.input.clone(),
                                is_input_complete: true,
                            })
                            .collect::<Vec<_>>();

                        tool_names_by_id.extend(
                            tool_uses
                                .iter()
                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
                        );

                        this.tool_uses_by_assistant_message
                            .insert(message.id, tool_uses);

                        for tool_result in &message.tool_results {
                            let tool_use_id = tool_result.tool_use_id.clone();
                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
                                continue;
                            };

                            this.tool_results.insert(
                                tool_use_id.clone(),
                                LanguageModelToolResult {
                                    tool_use_id: tool_use_id.clone(),
                                    tool_name: tool_use.clone(),
                                    is_error: tool_result.is_error,
                                    content: tool_result.content.clone(),
                                    output: tool_result.output.clone(),
                                },
                            );

                            if let Some(window) = &mut window {
                                if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
                                    if let Some(output) = tool_result.output.clone() {
                                        if let Some(card) = tool.deserialize_card(
                                            output,
                                            project.clone(),
                                            window,
                                            cx,
                                        ) {
                                            this.tool_result_cards.insert(tool_use_id, card);
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                Role::System | Role::User => {}
            }
        }

        this
    }

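    /// Cancels all pending tool uses that are not already in the error state,
    /// recording a "Tool canceled by user" error result for each, and returns
    /// the cancelled uses.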
    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
        let mut cancelled_tool_uses = Vec::new();
        self.pending_tool_uses_by_id
            .retain(|tool_use_id, tool_use| {
                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
                    return true;
                }

                let content = "Tool canceled by user".into();
                self.tool_results.insert(
                    tool_use_id.clone(),
                    LanguageModelToolResult {
                        tool_use_id: tool_use_id.clone(),
                        tool_name: tool_use.name.clone(),
                        content,
                        output: None,
                        is_error: true,
                    },
                );
                cancelled_tool_uses.push(tool_use.clone());
                false
            });
        cancelled_tool_uses
    }

    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
        self.pending_tool_uses_by_id.values().collect()
    }

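    /// Returns the tool uses recorded for the given assistant message, with
    /// their current status, UI label, and icon resolved for display.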
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        let Some(tool_uses_for_message) = self.tool_uses_by_assistant_message.get(&id) else {
            return Vec::new();
        };

        let mut tool_uses = Vec::new();

        for tool_use in tool_uses_for_message.iter() {
            let tool_result = self.tool_results.get(&tool_use.id);

            let status = (|| {
                if let Some(tool_result) = tool_result {
                    let content = tool_result
                        .content
                        .to_str()
                        .map(|str| str.to_owned().into())
                        .unwrap_or_default();

                    return if tool_result.is_error {
                        ToolUseStatus::Error(content)
                    } else {
                        ToolUseStatus::Finished(content)
                    };
                }

                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
                    match pending_tool_use.status {
                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
                        PendingToolUseStatus::NeedsConfirmation { .. } => {
                            ToolUseStatus::NeedsConfirmation
                        }
                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
                        PendingToolUseStatus::Error(ref err) => {
                            ToolUseStatus::Error(err.clone().into())
                        }
                        PendingToolUseStatus::InputStillStreaming => {
                            ToolUseStatus::InputStillStreaming
                        }
                    }
                } else {
                    ToolUseStatus::Pending
                }
            })();

            let (icon, needs_confirmation) =
                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
                } else {
                    (IconName::Cog, false)
                };

            tool_uses.push(ToolUse {
                id: tool_use.id.clone(),
                name: tool_use.name.clone().into(),
                ui_text: self.tool_ui_label(
                    &tool_use.name,
                    &tool_use.input,
                    tool_use.is_input_complete,
                    cx,
                ),
                input: tool_use.input.clone(),
                status,
                icon,
                needs_confirmation,
            })
        }

        tool_uses
    }

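    /// Returns the label to display for a tool use, delegating to the tool's
    /// own UI text (or its streaming variant while the input is incomplete).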
    pub fn tool_ui_label(
        &self,
        tool_name: &str,
        input: &serde_json::Value,
        is_input_complete: bool,
        cx: &App,
    ) -> SharedString {
        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
            if is_input_complete {
                tool.ui_text(input).into()
            } else {
                tool.still_streaming_ui_text(input).into()
            }
        } else {
            format!("Unknown tool {tool_name:?}").into()
        }
    }

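    /// Returns the results recorded for the tool uses requested by the given
    /// assistant message, skipping uses that have no result yet.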
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        let Some(tool_uses) = self
            .tool_uses_by_assistant_message
            .get(&assistant_message_id)
        else {
            return Vec::new();
        };

        tool_uses
            .iter()
            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
            .collect()
    }

    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
        self.tool_uses_by_assistant_message
            .get(&assistant_message_id)
            .map_or(false, |results| !results.is_empty())
    }

    pub fn tool_result(
        &self,
        tool_use_id: &LanguageModelToolUseId,
    ) -> Option<&LanguageModelToolResult> {
        self.tool_results.get(tool_use_id)
    }

    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
        self.tool_result_cards.get(tool_use_id)
    }

    pub fn insert_tool_result_card(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        card: AnyToolCard,
    ) {
        self.tool_result_cards.insert(tool_use_id, card);
    }

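    /// Records a tool use requested by the given assistant message, replacing
    /// any previously streamed use with the same ID, and registers it as
    /// pending. Returns the UI label for the tool use.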
    pub fn request_tool_use(
        &mut self,
        assistant_message_id: MessageId,
        tool_use: LanguageModelToolUse,
        metadata: ToolUseMetadata,
        cx: &App,
    ) -> Arc<str> {
        let tool_uses = self
            .tool_uses_by_assistant_message
            .entry(assistant_message_id)
            .or_default();

        let mut existing_tool_use_found = false;

        for existing_tool_use in tool_uses.iter_mut() {
            if existing_tool_use.id == tool_use.id {
                *existing_tool_use = tool_use.clone();
                existing_tool_use_found = true;
            }
        }

        if !existing_tool_use_found {
            tool_uses.push(tool_use.clone());
        }

        let status = if tool_use.is_input_complete {
            self.tool_use_metadata_by_id
                .insert(tool_use.id.clone(), metadata);

            PendingToolUseStatus::Idle
        } else {
            PendingToolUseStatus::InputStillStreaming
        };

        let ui_text: Arc<str> = self
            .tool_ui_label(
                &tool_use.name,
                &tool_use.input,
                tool_use.is_input_complete,
                cx,
            )
            .into();

        let may_perform_edits = self
            .tools
            .read(cx)
            .tool(&tool_use.name, cx)
            .is_some_and(|tool| tool.may_perform_edits());

        self.pending_tool_uses_by_id.insert(
            tool_use.id.clone(),
            PendingToolUse {
                assistant_message_id,
                id: tool_use.id,
                name: tool_use.name.clone(),
                ui_text: ui_text.clone(),
                input: tool_use.input,
                may_perform_edits,
                status,
            },
        );

        ui_text
    }

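    /// Marks a pending tool use as running and updates its UI label, holding
    /// on to the task that is executing the tool.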
    pub fn run_pending_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: SharedString,
        task: Task<()>,
    ) {
        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
            tool_use.ui_text = ui_text.into();
            tool_use.status = PendingToolUseStatus::Running {
                _task: task.shared(),
            };
        }
    }

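    /// Moves a pending tool use into the needs-confirmation state, storing
    /// everything required to run it once the user approves.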
    pub fn confirm_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<Arc<str>>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
    ) {
        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
            let ui_text = ui_text.into();
            tool_use.ui_text = ui_text.clone();
            let confirmation = Confirmation {
                tool_use_id,
                input,
                request,
                tool,
                ui_text,
            };
            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
        }
    }

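    /// Records the output of a finished tool, truncating text results that
    /// exceed the configured model's estimated token budget and converting
    /// errors into error results. Returns the corresponding pending tool use,
    /// if any.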
    pub fn insert_tool_output(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        output: Result<ToolResultOutput>,
        configured_model: Option<&ConfiguredModel>,
    ) -> Option<PendingToolUse> {
        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);

        telemetry::event!(
            "Agent Tool Finished",
            model = metadata
                .as_ref()
                .map(|metadata| metadata.model.telemetry_id()),
            model_provider = metadata
                .as_ref()
                .map(|metadata| metadata.model.provider_id().to_string()),
            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
            tool_name,
            success = output.is_ok()
        );

        match output {
            Ok(output) => {
                let tool_result = output.content;
                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);

                // Protect from overly large output
                let tool_output_limit = configured_model
                    .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
                    .unwrap_or(usize::MAX);

                let content = match tool_result {
                    ToolResultContent::Text(text) => {
                        let text = if text.len() < tool_output_limit {
                            text
                        } else {
                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
                            format!(
                                "Tool result too long. The first {} bytes:\n\n{}",
                                truncated.len(),
                                truncated
                            )
                        };
                        LanguageModelToolResultContent::Text(text.into())
                    }
                    ToolResultContent::Image(language_model_image) => {
                        if language_model_image.estimate_tokens() < tool_output_limit {
                            LanguageModelToolResultContent::Image(language_model_image)
                        } else {
                            self.tool_results.insert(
                                tool_use_id.clone(),
                                LanguageModelToolResult {
                                    tool_use_id: tool_use_id.clone(),
                                    tool_name,
                                    content: "Tool responded with an image that would exceed the remaining tokens".into(),
                                    is_error: true,
                                    output: None,
                                },
                            );

                            return old_use;
                        }
                    }
                };

                self.tool_results.insert(
                    tool_use_id.clone(),
                    LanguageModelToolResult {
                        tool_use_id: tool_use_id.clone(),
                        tool_name,
                        content,
                        is_error: false,
                        output: output.output,
                    },
                );

                old_use
            }
            Err(err) => {
                self.tool_results.insert(
                    tool_use_id.clone(),
                    LanguageModelToolResult {
                        tool_use_id: tool_use_id.clone(),
                        tool_name,
                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
                        is_error: true,
                        output: None,
                    },
                );

                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
                }

                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
            }
        }
    }

    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
        self.tool_uses_by_assistant_message
            .contains_key(&assistant_message_id)
    }

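    /// Returns an iterator over the tool uses recorded for the given assistant
    /// message, paired with their results (if any).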
    pub fn tool_results(
        &self,
        assistant_message_id: MessageId,
    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
        self.tool_uses_by_assistant_message
            .get(&assistant_message_id)
            .into_iter()
            .flatten()
            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
    }
}

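/// A tool use that has been requested by the model but whose result has not
/// been recorded yet.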
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    pub name: Arc<str>,
    pub ui_text: Arc<str>,
    pub input: serde_json::Value,
    pub status: PendingToolUseStatus,
    pub may_perform_edits: bool,
}

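/// Everything needed to run a tool once the user has confirmed it.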
#[derive(Debug, Clone)]
pub struct Confirmation {
    pub tool_use_id: LanguageModelToolUseId,
    pub input: serde_json::Value,
    pub ui_text: Arc<str>,
    pub request: Arc<LanguageModelRequest>,
    pub tool: Arc<dyn Tool>,
}

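/// The state of a pending tool use: streaming input, waiting to run, awaiting
/// user confirmation, running, or failed.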
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    InputStillStreaming,
    Idle,
    NeedsConfirmation(Arc<Confirmation>),
    Running { _task: Shared<Task<()>> },
    Error(#[allow(unused)] Arc<str>),
}

impl PendingToolUseStatus {
    pub fn is_idle(&self) -> bool {
        matches!(self, PendingToolUseStatus::Idle)
    }

    pub fn is_error(&self) -> bool {
        matches!(self, PendingToolUseStatus::Error(_))
    }

    pub fn needs_confirmation(&self) -> bool {
        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
    }
}

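/// Metadata captured when a tool use is requested, used for telemetry when the
/// tool finishes.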
#[derive(Clone)]
pub struct ToolUseMetadata {
    pub model: Arc<dyn LanguageModel>,
    pub thread_id: ThreadId,
    pub prompt_id: PromptId,
}