1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{
5 AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
6};
7use collections::HashMap;
8use futures::FutureExt as _;
9use futures::future::Shared;
10use gpui::{App, Entity, SharedString, Task};
11use language_model::{
12 ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
13 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
14};
15use project::Project;
16use ui::{IconName, Window};
17use util::truncate_lines_to_byte_limit;
18
19use crate::thread::{MessageId, PromptId, ThreadId};
20use crate::thread_store::SerializedMessage;
21
22#[derive(Debug)]
23pub struct ToolUse {
24 pub id: LanguageModelToolUseId,
25 pub name: SharedString,
26 pub ui_text: SharedString,
27 pub status: ToolUseStatus,
28 pub input: serde_json::Value,
29 pub icon: ui::IconName,
30 pub needs_confirmation: bool,
31}
32
33pub struct ToolUseState {
34 tools: Entity<ToolWorkingSet>,
35 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
36 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
37 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
38 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
39 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
40}
41
42impl ToolUseState {
43 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
44 Self {
45 tools,
46 tool_uses_by_assistant_message: HashMap::default(),
47 tool_results: HashMap::default(),
48 pending_tool_uses_by_id: HashMap::default(),
49 tool_result_cards: HashMap::default(),
50 tool_use_metadata_by_id: HashMap::default(),
51 }
52 }
53
54 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
55 ///
56 /// Accepts a function to filter the tools that should be used to populate the state.
57 ///
58 /// If `window` is `None` (e.g., when in headless mode or when running evals),
59 /// tool cards won't be deserialized
60 pub fn from_serialized_messages(
61 tools: Entity<ToolWorkingSet>,
62 messages: &[SerializedMessage],
63 project: Entity<Project>,
64 window: Option<&mut Window>, // None in headless mode
65 cx: &mut App,
66 ) -> Self {
67 let mut this = Self::new(tools);
68 let mut tool_names_by_id = HashMap::default();
69 let mut window = window;
70
71 for message in messages {
72 match message.role {
73 Role::Assistant => {
74 if !message.tool_uses.is_empty() {
75 let tool_uses = message
76 .tool_uses
77 .iter()
78 .map(|tool_use| LanguageModelToolUse {
79 id: tool_use.id.clone(),
80 name: tool_use.name.clone().into(),
81 raw_input: tool_use.input.to_string(),
82 input: tool_use.input.clone(),
83 is_input_complete: true,
84 })
85 .collect::<Vec<_>>();
86
87 tool_names_by_id.extend(
88 tool_uses
89 .iter()
90 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
91 );
92
93 this.tool_uses_by_assistant_message
94 .insert(message.id, tool_uses);
95
96 for tool_result in &message.tool_results {
97 let tool_use_id = tool_result.tool_use_id.clone();
98 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
99 log::warn!("no tool name found for tool use: {tool_use_id:?}");
100 continue;
101 };
102
103 this.tool_results.insert(
104 tool_use_id.clone(),
105 LanguageModelToolResult {
106 tool_use_id: tool_use_id.clone(),
107 tool_name: tool_use.clone(),
108 is_error: tool_result.is_error,
109 content: tool_result.content.clone(),
110 output: tool_result.output.clone(),
111 },
112 );
113
114 if let Some(window) = &mut window {
115 if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
116 if let Some(output) = tool_result.output.clone() {
117 if let Some(card) = tool.deserialize_card(
118 output,
119 project.clone(),
120 window,
121 cx,
122 ) {
123 this.tool_result_cards.insert(tool_use_id, card);
124 }
125 }
126 }
127 }
128 }
129 }
130 }
131 Role::System | Role::User => {}
132 }
133 }
134
135 this
136 }
137
138 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
139 let mut cancelled_tool_uses = Vec::new();
140 self.pending_tool_uses_by_id
141 .retain(|tool_use_id, tool_use| {
142 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
143 return true;
144 }
145
146 let content = "Tool canceled by user".into();
147 self.tool_results.insert(
148 tool_use_id.clone(),
149 LanguageModelToolResult {
150 tool_use_id: tool_use_id.clone(),
151 tool_name: tool_use.name.clone(),
152 content,
153 output: None,
154 is_error: true,
155 },
156 );
157 cancelled_tool_uses.push(tool_use.clone());
158 false
159 });
160 cancelled_tool_uses
161 }
162
163 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
164 self.pending_tool_uses_by_id.values().collect()
165 }
166
167 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
168 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169 return Vec::new();
170 };
171
172 let mut tool_uses = Vec::new();
173
174 for tool_use in tool_uses_for_message.iter() {
175 let tool_result = self.tool_results.get(&tool_use.id);
176
177 let status = (|| {
178 if let Some(tool_result) = tool_result {
179 let content = tool_result
180 .content
181 .to_str()
182 .map(|str| str.to_owned().into())
183 .unwrap_or_default();
184
185 return if tool_result.is_error {
186 ToolUseStatus::Error(content)
187 } else {
188 ToolUseStatus::Finished(content)
189 };
190 }
191
192 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
193 match pending_tool_use.status {
194 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
195 PendingToolUseStatus::NeedsConfirmation { .. } => {
196 ToolUseStatus::NeedsConfirmation
197 }
198 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
199 PendingToolUseStatus::Error(ref err) => {
200 ToolUseStatus::Error(err.clone().into())
201 }
202 PendingToolUseStatus::InputStillStreaming => {
203 ToolUseStatus::InputStillStreaming
204 }
205 }
206 } else {
207 ToolUseStatus::Pending
208 }
209 })();
210
211 let (icon, needs_confirmation) =
212 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
213 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
214 } else {
215 (IconName::Cog, false)
216 };
217
218 tool_uses.push(ToolUse {
219 id: tool_use.id.clone(),
220 name: tool_use.name.clone().into(),
221 ui_text: self.tool_ui_label(
222 &tool_use.name,
223 &tool_use.input,
224 tool_use.is_input_complete,
225 cx,
226 ),
227 input: tool_use.input.clone(),
228 status,
229 icon,
230 needs_confirmation,
231 })
232 }
233
234 tool_uses
235 }
236
237 pub fn tool_ui_label(
238 &self,
239 tool_name: &str,
240 input: &serde_json::Value,
241 is_input_complete: bool,
242 cx: &App,
243 ) -> SharedString {
244 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
245 if is_input_complete {
246 tool.ui_text(input).into()
247 } else {
248 tool.still_streaming_ui_text(input).into()
249 }
250 } else {
251 format!("Unknown tool {tool_name:?}").into()
252 }
253 }
254
255 pub fn tool_results_for_message(
256 &self,
257 assistant_message_id: MessageId,
258 ) -> Vec<&LanguageModelToolResult> {
259 let Some(tool_uses) = self
260 .tool_uses_by_assistant_message
261 .get(&assistant_message_id)
262 else {
263 return Vec::new();
264 };
265
266 tool_uses
267 .iter()
268 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
269 .collect()
270 }
271
272 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
273 self.tool_uses_by_assistant_message
274 .get(&assistant_message_id)
275 .map_or(false, |results| !results.is_empty())
276 }
277
278 pub fn tool_result(
279 &self,
280 tool_use_id: &LanguageModelToolUseId,
281 ) -> Option<&LanguageModelToolResult> {
282 self.tool_results.get(tool_use_id)
283 }
284
285 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
286 self.tool_result_cards.get(tool_use_id)
287 }
288
289 pub fn insert_tool_result_card(
290 &mut self,
291 tool_use_id: LanguageModelToolUseId,
292 card: AnyToolCard,
293 ) {
294 self.tool_result_cards.insert(tool_use_id, card);
295 }
296
297 pub fn request_tool_use(
298 &mut self,
299 assistant_message_id: MessageId,
300 tool_use: LanguageModelToolUse,
301 metadata: ToolUseMetadata,
302 cx: &App,
303 ) -> Arc<str> {
304 let tool_uses = self
305 .tool_uses_by_assistant_message
306 .entry(assistant_message_id)
307 .or_default();
308
309 let mut existing_tool_use_found = false;
310
311 for existing_tool_use in tool_uses.iter_mut() {
312 if existing_tool_use.id == tool_use.id {
313 *existing_tool_use = tool_use.clone();
314 existing_tool_use_found = true;
315 }
316 }
317
318 if !existing_tool_use_found {
319 tool_uses.push(tool_use.clone());
320 }
321
322 let status = if tool_use.is_input_complete {
323 self.tool_use_metadata_by_id
324 .insert(tool_use.id.clone(), metadata);
325
326 PendingToolUseStatus::Idle
327 } else {
328 PendingToolUseStatus::InputStillStreaming
329 };
330
331 let ui_text: Arc<str> = self
332 .tool_ui_label(
333 &tool_use.name,
334 &tool_use.input,
335 tool_use.is_input_complete,
336 cx,
337 )
338 .into();
339
340 self.pending_tool_uses_by_id.insert(
341 tool_use.id.clone(),
342 PendingToolUse {
343 assistant_message_id,
344 id: tool_use.id,
345 name: tool_use.name.clone(),
346 ui_text: ui_text.clone(),
347 input: tool_use.input,
348 status,
349 },
350 );
351
352 ui_text
353 }
354
355 pub fn run_pending_tool(
356 &mut self,
357 tool_use_id: LanguageModelToolUseId,
358 ui_text: SharedString,
359 task: Task<()>,
360 ) {
361 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
362 tool_use.ui_text = ui_text.into();
363 tool_use.status = PendingToolUseStatus::Running {
364 _task: task.shared(),
365 };
366 }
367 }
368
369 pub fn confirm_tool_use(
370 &mut self,
371 tool_use_id: LanguageModelToolUseId,
372 ui_text: impl Into<Arc<str>>,
373 input: serde_json::Value,
374 request: Arc<LanguageModelRequest>,
375 tool: Arc<dyn Tool>,
376 ) {
377 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
378 let ui_text = ui_text.into();
379 tool_use.ui_text = ui_text.clone();
380 let confirmation = Confirmation {
381 tool_use_id,
382 input,
383 request,
384 tool,
385 ui_text,
386 };
387 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
388 }
389 }
390
391 pub fn insert_tool_output(
392 &mut self,
393 tool_use_id: LanguageModelToolUseId,
394 tool_name: Arc<str>,
395 output: Result<ToolResultOutput>,
396 configured_model: Option<&ConfiguredModel>,
397 ) -> Option<PendingToolUse> {
398 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
399
400 telemetry::event!(
401 "Agent Tool Finished",
402 model = metadata
403 .as_ref()
404 .map(|metadata| metadata.model.telemetry_id()),
405 model_provider = metadata
406 .as_ref()
407 .map(|metadata| metadata.model.provider_id().to_string()),
408 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
409 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
410 tool_name,
411 success = output.is_ok()
412 );
413
414 match output {
415 Ok(output) => {
416 let tool_result = output.content;
417 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
418
419 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
420
421 // Protect from overly large output
422 let tool_output_limit = configured_model
423 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
424 .unwrap_or(usize::MAX);
425
426 let content = match tool_result {
427 ToolResultContent::Text(text) => {
428 let text = if text.len() < tool_output_limit {
429 text
430 } else {
431 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
432 format!(
433 "Tool result too long. The first {} bytes:\n\n{}",
434 truncated.len(),
435 truncated
436 )
437 };
438 LanguageModelToolResultContent::Text(text.into())
439 }
440 ToolResultContent::Image(language_model_image) => {
441 if language_model_image.estimate_tokens() < tool_output_limit {
442 LanguageModelToolResultContent::Image(language_model_image)
443 } else {
444 self.tool_results.insert(
445 tool_use_id.clone(),
446 LanguageModelToolResult {
447 tool_use_id: tool_use_id.clone(),
448 tool_name,
449 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
450 is_error: true,
451 output: None,
452 },
453 );
454
455 return old_use;
456 }
457 }
458 };
459
460 self.tool_results.insert(
461 tool_use_id.clone(),
462 LanguageModelToolResult {
463 tool_use_id: tool_use_id.clone(),
464 tool_name,
465 content,
466 is_error: false,
467 output: output.output,
468 },
469 );
470
471 old_use
472 }
473 Err(err) => {
474 self.tool_results.insert(
475 tool_use_id.clone(),
476 LanguageModelToolResult {
477 tool_use_id: tool_use_id.clone(),
478 tool_name,
479 content: LanguageModelToolResultContent::Text(err.to_string().into()),
480 is_error: true,
481 output: None,
482 },
483 );
484
485 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
486 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
487 }
488
489 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
490 }
491 }
492 }
493
494 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
495 self.tool_uses_by_assistant_message
496 .contains_key(&assistant_message_id)
497 }
498
499 pub fn tool_results(
500 &self,
501 assistant_message_id: MessageId,
502 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
503 self.tool_uses_by_assistant_message
504 .get(&assistant_message_id)
505 .into_iter()
506 .flatten()
507 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
508 }
509}
510
511#[derive(Debug, Clone)]
512pub struct PendingToolUse {
513 pub id: LanguageModelToolUseId,
514 /// The ID of the Assistant message in which the tool use was requested.
515 #[allow(unused)]
516 pub assistant_message_id: MessageId,
517 pub name: Arc<str>,
518 pub ui_text: Arc<str>,
519 pub input: serde_json::Value,
520 pub status: PendingToolUseStatus,
521}
522
523#[derive(Debug, Clone)]
524pub struct Confirmation {
525 pub tool_use_id: LanguageModelToolUseId,
526 pub input: serde_json::Value,
527 pub ui_text: Arc<str>,
528 pub request: Arc<LanguageModelRequest>,
529 pub tool: Arc<dyn Tool>,
530}
531
532#[derive(Debug, Clone)]
533pub enum PendingToolUseStatus {
534 InputStillStreaming,
535 Idle,
536 NeedsConfirmation(Arc<Confirmation>),
537 Running { _task: Shared<Task<()>> },
538 Error(#[allow(unused)] Arc<str>),
539}
540
541impl PendingToolUseStatus {
542 pub fn is_idle(&self) -> bool {
543 matches!(self, PendingToolUseStatus::Idle)
544 }
545
546 pub fn is_error(&self) -> bool {
547 matches!(self, PendingToolUseStatus::Error(_))
548 }
549
550 pub fn needs_confirmation(&self) -> bool {
551 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
552 }
553}
554
555#[derive(Clone)]
556pub struct ToolUseMetadata {
557 pub model: Arc<dyn LanguageModel>,
558 pub thread_id: ThreadId,
559 pub prompt_id: PromptId,
560}