1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, Role,
12};
13use project::Project;
14use ui::{IconName, Window};
15use util::truncate_lines_to_byte_limit;
16
17use crate::thread::{MessageId, PromptId, ThreadId};
18use crate::thread_store::SerializedMessage;
19
/// A tool use prepared for display, as built by [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the requested tool.
    pub name: SharedString,
    /// Label produced by [`ToolUseState::tool_ui_label`] (the "still streaming" variant while
    /// the input is incomplete, or an "Unknown tool" label when the tool is not available).
    pub ui_text: SharedString,
    /// Current status, derived from the recorded result if any, otherwise from the pending entry.
    pub status: ToolUseStatus,
    /// JSON input the tool was invoked with.
    pub input: serde_json::Value,
    /// The tool's icon, or `IconName::Cog` when the tool is not in the working set.
    pub icon: ui::IconName,
    /// Whether the tool reports that this input needs confirmation (`false` for unknown tools).
    pub needs_confirmation: bool,
}
30
/// Tracks the tool uses requested by assistant messages in a thread, together with their
/// results, in-flight state, UI cards and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in request order.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Results (successful, failed or cancelled) keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have not finished yet. Entries are removed on successful output or
    /// cancellation; entries whose status is an error are kept.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// UI cards associated with tool results, keyed by tool use id.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Telemetry metadata recorded once a tool use's input is complete and consumed when its
    /// output is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
39
40impl ToolUseState {
41 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
42 Self {
43 tools,
44 tool_uses_by_assistant_message: HashMap::default(),
45 tool_results: HashMap::default(),
46 pending_tool_uses_by_id: HashMap::default(),
47 tool_result_cards: HashMap::default(),
48 tool_use_metadata_by_id: HashMap::default(),
49 }
50 }
51
52 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
53 ///
54 /// Accepts a function to filter the tools that should be used to populate the state.
55 ///
56 /// If `window` is `None` (e.g., when in headless mode or when running evals),
57 /// tool cards won't be deserialized
58 pub fn from_serialized_messages(
59 tools: Entity<ToolWorkingSet>,
60 messages: &[SerializedMessage],
61 project: Entity<Project>,
62 window: Option<&mut Window>, // None in headless mode
63 cx: &mut App,
64 ) -> Self {
65 let mut this = Self::new(tools);
66 let mut tool_names_by_id = HashMap::default();
67 let mut window = window;
68
69 for message in messages {
70 match message.role {
71 Role::Assistant => {
72 if !message.tool_uses.is_empty() {
73 let tool_uses = message
74 .tool_uses
75 .iter()
76 .map(|tool_use| LanguageModelToolUse {
77 id: tool_use.id.clone(),
78 name: tool_use.name.clone().into(),
79 raw_input: tool_use.input.to_string(),
80 input: tool_use.input.clone(),
81 is_input_complete: true,
82 })
83 .collect::<Vec<_>>();
84
85 tool_names_by_id.extend(
86 tool_uses
87 .iter()
88 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
89 );
90
91 this.tool_uses_by_assistant_message
92 .insert(message.id, tool_uses);
93
94 for tool_result in &message.tool_results {
95 let tool_use_id = tool_result.tool_use_id.clone();
96 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
97 log::warn!("no tool name found for tool use: {tool_use_id:?}");
98 continue;
99 };
100
101 this.tool_results.insert(
102 tool_use_id.clone(),
103 LanguageModelToolResult {
104 tool_use_id: tool_use_id.clone(),
105 tool_name: tool_use.clone(),
106 is_error: tool_result.is_error,
107 content: tool_result.content.clone(),
108 output: tool_result.output.clone(),
109 },
110 );
111
112 if let Some(window) = &mut window {
113 if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
114 if let Some(output) = tool_result.output.clone() {
115 if let Some(card) = tool.deserialize_card(
116 output,
117 project.clone(),
118 window,
119 cx,
120 ) {
121 this.tool_result_cards.insert(tool_use_id, card);
122 }
123 }
124 }
125 }
126 }
127 }
128 }
129 Role::System | Role::User => {}
130 }
131 }
132
133 this
134 }
135
136 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
137 let mut cancelled_tool_uses = Vec::new();
138 self.pending_tool_uses_by_id
139 .retain(|tool_use_id, tool_use| {
140 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
141 return true;
142 }
143
144 let content = "Tool canceled by user".into();
145 self.tool_results.insert(
146 tool_use_id.clone(),
147 LanguageModelToolResult {
148 tool_use_id: tool_use_id.clone(),
149 tool_name: tool_use.name.clone(),
150 content,
151 output: None,
152 is_error: true,
153 },
154 );
155 cancelled_tool_uses.push(tool_use.clone());
156 false
157 });
158 cancelled_tool_uses
159 }
160
161 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
162 self.pending_tool_uses_by_id.values().collect()
163 }
164
165 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
166 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
167 return Vec::new();
168 };
169
170 let mut tool_uses = Vec::new();
171
172 for tool_use in tool_uses_for_message.iter() {
173 let tool_result = self.tool_results.get(&tool_use.id);
174
175 let status = (|| {
176 if let Some(tool_result) = tool_result {
177 return if tool_result.is_error {
178 ToolUseStatus::Error(tool_result.content.clone().into())
179 } else {
180 ToolUseStatus::Finished(tool_result.content.clone().into())
181 };
182 }
183
184 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
185 match pending_tool_use.status {
186 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
187 PendingToolUseStatus::NeedsConfirmation { .. } => {
188 ToolUseStatus::NeedsConfirmation
189 }
190 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
191 PendingToolUseStatus::Error(ref err) => {
192 ToolUseStatus::Error(err.clone().into())
193 }
194 PendingToolUseStatus::InputStillStreaming => {
195 ToolUseStatus::InputStillStreaming
196 }
197 }
198 } else {
199 ToolUseStatus::Pending
200 }
201 })();
202
203 let (icon, needs_confirmation) =
204 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
205 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
206 } else {
207 (IconName::Cog, false)
208 };
209
210 tool_uses.push(ToolUse {
211 id: tool_use.id.clone(),
212 name: tool_use.name.clone().into(),
213 ui_text: self.tool_ui_label(
214 &tool_use.name,
215 &tool_use.input,
216 tool_use.is_input_complete,
217 cx,
218 ),
219 input: tool_use.input.clone(),
220 status,
221 icon,
222 needs_confirmation,
223 })
224 }
225
226 tool_uses
227 }
228
229 pub fn tool_ui_label(
230 &self,
231 tool_name: &str,
232 input: &serde_json::Value,
233 is_input_complete: bool,
234 cx: &App,
235 ) -> SharedString {
236 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
237 if is_input_complete {
238 tool.ui_text(input).into()
239 } else {
240 tool.still_streaming_ui_text(input).into()
241 }
242 } else {
243 format!("Unknown tool {tool_name:?}").into()
244 }
245 }
246
247 pub fn tool_results_for_message(
248 &self,
249 assistant_message_id: MessageId,
250 ) -> Vec<&LanguageModelToolResult> {
251 let Some(tool_uses) = self
252 .tool_uses_by_assistant_message
253 .get(&assistant_message_id)
254 else {
255 return Vec::new();
256 };
257
258 tool_uses
259 .iter()
260 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
261 .collect()
262 }
263
264 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
265 self.tool_uses_by_assistant_message
266 .get(&assistant_message_id)
267 .map_or(false, |results| !results.is_empty())
268 }
269
270 pub fn tool_result(
271 &self,
272 tool_use_id: &LanguageModelToolUseId,
273 ) -> Option<&LanguageModelToolResult> {
274 self.tool_results.get(tool_use_id)
275 }
276
277 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
278 self.tool_result_cards.get(tool_use_id)
279 }
280
281 pub fn insert_tool_result_card(
282 &mut self,
283 tool_use_id: LanguageModelToolUseId,
284 card: AnyToolCard,
285 ) {
286 self.tool_result_cards.insert(tool_use_id, card);
287 }
288
289 pub fn request_tool_use(
290 &mut self,
291 assistant_message_id: MessageId,
292 tool_use: LanguageModelToolUse,
293 metadata: ToolUseMetadata,
294 cx: &App,
295 ) -> Arc<str> {
296 let tool_uses = self
297 .tool_uses_by_assistant_message
298 .entry(assistant_message_id)
299 .or_default();
300
301 let mut existing_tool_use_found = false;
302
303 for existing_tool_use in tool_uses.iter_mut() {
304 if existing_tool_use.id == tool_use.id {
305 *existing_tool_use = tool_use.clone();
306 existing_tool_use_found = true;
307 }
308 }
309
310 if !existing_tool_use_found {
311 tool_uses.push(tool_use.clone());
312 }
313
314 let status = if tool_use.is_input_complete {
315 self.tool_use_metadata_by_id
316 .insert(tool_use.id.clone(), metadata);
317
318 PendingToolUseStatus::Idle
319 } else {
320 PendingToolUseStatus::InputStillStreaming
321 };
322
323 let ui_text: Arc<str> = self
324 .tool_ui_label(
325 &tool_use.name,
326 &tool_use.input,
327 tool_use.is_input_complete,
328 cx,
329 )
330 .into();
331
332 self.pending_tool_uses_by_id.insert(
333 tool_use.id.clone(),
334 PendingToolUse {
335 assistant_message_id,
336 id: tool_use.id,
337 name: tool_use.name.clone(),
338 ui_text: ui_text.clone(),
339 input: tool_use.input,
340 status,
341 },
342 );
343
344 ui_text
345 }
346
347 pub fn run_pending_tool(
348 &mut self,
349 tool_use_id: LanguageModelToolUseId,
350 ui_text: SharedString,
351 task: Task<()>,
352 ) {
353 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
354 tool_use.ui_text = ui_text.into();
355 tool_use.status = PendingToolUseStatus::Running {
356 _task: task.shared(),
357 };
358 }
359 }
360
361 pub fn confirm_tool_use(
362 &mut self,
363 tool_use_id: LanguageModelToolUseId,
364 ui_text: impl Into<Arc<str>>,
365 input: serde_json::Value,
366 request: Arc<LanguageModelRequest>,
367 tool: Arc<dyn Tool>,
368 ) {
369 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
370 let ui_text = ui_text.into();
371 tool_use.ui_text = ui_text.clone();
372 let confirmation = Confirmation {
373 tool_use_id,
374 input,
375 request,
376 tool,
377 ui_text,
378 };
379 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
380 }
381 }
382
383 pub fn insert_tool_output(
384 &mut self,
385 tool_use_id: LanguageModelToolUseId,
386 tool_name: Arc<str>,
387 output: Result<ToolResultOutput>,
388 configured_model: Option<&ConfiguredModel>,
389 ) -> Option<PendingToolUse> {
390 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
391
392 telemetry::event!(
393 "Agent Tool Finished",
394 model = metadata
395 .as_ref()
396 .map(|metadata| metadata.model.telemetry_id()),
397 model_provider = metadata
398 .as_ref()
399 .map(|metadata| metadata.model.provider_id().to_string()),
400 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
401 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
402 tool_name,
403 success = output.is_ok()
404 );
405
406 match output {
407 Ok(output) => {
408 let tool_result = output.content;
409 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
410
411 // Protect from clearly large output
412 let tool_output_limit = configured_model
413 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
414 .unwrap_or(usize::MAX);
415
416 let tool_result = if tool_result.len() <= tool_output_limit {
417 tool_result
418 } else {
419 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
420
421 format!(
422 "Tool result too long. The first {} bytes:\n\n{}",
423 truncated.len(),
424 truncated
425 )
426 };
427
428 self.tool_results.insert(
429 tool_use_id.clone(),
430 LanguageModelToolResult {
431 tool_use_id: tool_use_id.clone(),
432 tool_name,
433 content: tool_result.into(),
434 is_error: false,
435 output: output.output,
436 },
437 );
438 self.pending_tool_uses_by_id.remove(&tool_use_id)
439 }
440 Err(err) => {
441 self.tool_results.insert(
442 tool_use_id.clone(),
443 LanguageModelToolResult {
444 tool_use_id: tool_use_id.clone(),
445 tool_name,
446 content: err.to_string().into(),
447 is_error: true,
448 output: None,
449 },
450 );
451
452 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
453 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
454 }
455
456 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
457 }
458 }
459 }
460
461 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
462 self.tool_uses_by_assistant_message
463 .contains_key(&assistant_message_id)
464 }
465
466 pub fn tool_results(
467 &self,
468 assistant_message_id: MessageId,
469 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
470 self.tool_uses_by_assistant_message
471 .get(&assistant_message_id)
472 .into_iter()
473 .flatten()
474 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
475 }
476}
477
/// A tool use tracked by [`ToolUseState`] until it produces a successful result or is
/// cancelled. Tool uses whose status becomes an error remain tracked.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Current UI label; updated by `run_pending_tool` and `confirm_tool_use`.
    pub ui_text: Arc<str>,
    /// JSON input of the tool use.
    pub input: serde_json::Value,
    /// Progress of the tool use.
    pub status: PendingToolUseStatus,
}
489
/// Data captured by [`ToolUseState::confirm_tool_use`] while a tool use waits in the
/// `NeedsConfirmation` state.
///
/// NOTE(review): presumably used to run `tool` with `input`/`request` once confirmed —
/// verify against callers.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input the tool would be invoked with.
    pub input: serde_json::Value,
    /// UI label at the time confirmation was requested.
    pub ui_text: Arc<str>,
    /// The language model request associated with this tool use.
    pub request: Arc<LanguageModelRequest>,
    /// The tool to invoke.
    pub tool: Arc<dyn Tool>,
}
498
/// Progress of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The model is still streaming the tool's input.
    InputStillStreaming,
    /// The input is complete; the tool is not running or awaiting confirmation yet.
    Idle,
    /// Waiting for confirmation; set by `ToolUseState::confirm_tool_use`.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running. The shared task handle is stored here.
    // NOTE(review): presumably held so the task is not dropped (and cancelled) early — verify.
    Running { _task: Shared<Task<()>> },
    /// The tool failed with the given error message.
    // NOTE(review): the payload is read when mapping to `ToolUseStatus` in this file, so this
    // `allow(unused)` may be stale — verify.
    Error(#[allow(unused)] Arc<str>),
}
507
508impl PendingToolUseStatus {
509 pub fn is_idle(&self) -> bool {
510 matches!(self, PendingToolUseStatus::Idle)
511 }
512
513 pub fn is_error(&self) -> bool {
514 matches!(self, PendingToolUseStatus::Error(_))
515 }
516
517 pub fn needs_confirmation(&self) -> bool {
518 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
519 }
520}
521
/// Telemetry context for a tool use.
///
/// It is stored by `ToolUseState::request_tool_use` once the tool's input is complete. It is
/// then consumed by `ToolUseState::insert_tool_output` when reporting the
/// "Agent Tool Finished" event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model that requested the tool use (reported as its telemetry id and provider id).
    pub model: Arc<dyn LanguageModel>,
    /// Thread in which the tool use was requested.
    pub thread_id: ThreadId,
    /// Prompt that led to the tool use.
    pub prompt_id: PromptId,
}