1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{
5 AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
6};
7use collections::HashMap;
8use futures::FutureExt as _;
9use futures::future::Shared;
10use gpui::{App, Entity, SharedString, Task};
11use language_model::{
12 ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
13 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
14};
15use project::Project;
16use ui::{IconName, Window};
17use util::truncate_lines_to_byte_limit;
18
19use crate::thread::{MessageId, PromptId, ThreadId};
20use crate::thread_store::SerializedMessage;
21
22#[derive(Debug)]
23pub struct ToolUse {
24 pub id: LanguageModelToolUseId,
25 pub name: SharedString,
26 pub ui_text: SharedString,
27 pub status: ToolUseStatus,
28 pub input: serde_json::Value,
29 pub icon: ui::IconName,
30 pub needs_confirmation: bool,
31}
32
33pub struct ToolUseState {
34 tools: Entity<ToolWorkingSet>,
35 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
36 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
37 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
38 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
39 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
40}
41
42impl ToolUseState {
43 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
44 Self {
45 tools,
46 tool_uses_by_assistant_message: HashMap::default(),
47 tool_results: HashMap::default(),
48 pending_tool_uses_by_id: HashMap::default(),
49 tool_result_cards: HashMap::default(),
50 tool_use_metadata_by_id: HashMap::default(),
51 }
52 }
53
54 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
55 ///
56 /// Accepts a function to filter the tools that should be used to populate the state.
57 pub fn from_serialized_messages(
58 tools: Entity<ToolWorkingSet>,
59 messages: &[SerializedMessage],
60 project: Entity<Project>,
61 window: &mut Window,
62 cx: &mut App,
63 ) -> Self {
64 let mut this = Self::new(tools);
65 let mut tool_names_by_id = HashMap::default();
66
67 for message in messages {
68 match message.role {
69 Role::Assistant => {
70 if !message.tool_uses.is_empty() {
71 let tool_uses = message
72 .tool_uses
73 .iter()
74 .map(|tool_use| LanguageModelToolUse {
75 id: tool_use.id.clone(),
76 name: tool_use.name.clone().into(),
77 raw_input: tool_use.input.to_string(),
78 input: tool_use.input.clone(),
79 is_input_complete: true,
80 })
81 .collect::<Vec<_>>();
82
83 tool_names_by_id.extend(
84 tool_uses
85 .iter()
86 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
87 );
88
89 this.tool_uses_by_assistant_message
90 .insert(message.id, tool_uses);
91
92 for tool_result in &message.tool_results {
93 let tool_use_id = tool_result.tool_use_id.clone();
94 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
95 log::warn!("no tool name found for tool use: {tool_use_id:?}");
96 continue;
97 };
98
99 this.tool_results.insert(
100 tool_use_id.clone(),
101 LanguageModelToolResult {
102 tool_use_id: tool_use_id.clone(),
103 tool_name: tool_use.clone(),
104 is_error: tool_result.is_error,
105 content: tool_result.content.clone(),
106 output: tool_result.output.clone(),
107 },
108 );
109
110 if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
111 if let Some(output) = tool_result.output.clone() {
112 if let Some(card) =
113 tool.deserialize_card(output, project.clone(), window, cx)
114 {
115 this.tool_result_cards.insert(tool_use_id, card);
116 }
117 }
118 }
119 }
120 }
121 }
122 Role::System | Role::User => {}
123 }
124 }
125
126 this
127 }
128
129 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130 let mut cancelled_tool_uses = Vec::new();
131 self.pending_tool_uses_by_id
132 .retain(|tool_use_id, tool_use| {
133 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
134 return true;
135 }
136
137 let content = "Tool canceled by user".into();
138 self.tool_results.insert(
139 tool_use_id.clone(),
140 LanguageModelToolResult {
141 tool_use_id: tool_use_id.clone(),
142 tool_name: tool_use.name.clone(),
143 content,
144 output: None,
145 is_error: true,
146 },
147 );
148 cancelled_tool_uses.push(tool_use.clone());
149 false
150 });
151 cancelled_tool_uses
152 }
153
154 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
155 self.pending_tool_uses_by_id.values().collect()
156 }
157
158 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
159 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
160 return Vec::new();
161 };
162
163 let mut tool_uses = Vec::new();
164
165 for tool_use in tool_uses_for_message.iter() {
166 let tool_result = self.tool_results.get(&tool_use.id);
167
168 let status = (|| {
169 if let Some(tool_result) = tool_result {
170 let content = tool_result
171 .content
172 .to_str()
173 .map(|str| str.to_owned().into())
174 .unwrap_or_default();
175
176 return if tool_result.is_error {
177 ToolUseStatus::Error(content)
178 } else {
179 ToolUseStatus::Finished(content)
180 };
181 }
182
183 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
184 match pending_tool_use.status {
185 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
186 PendingToolUseStatus::NeedsConfirmation { .. } => {
187 ToolUseStatus::NeedsConfirmation
188 }
189 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
190 PendingToolUseStatus::Error(ref err) => {
191 ToolUseStatus::Error(err.clone().into())
192 }
193 PendingToolUseStatus::InputStillStreaming => {
194 ToolUseStatus::InputStillStreaming
195 }
196 }
197 } else {
198 ToolUseStatus::Pending
199 }
200 })();
201
202 let (icon, needs_confirmation) =
203 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
204 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
205 } else {
206 (IconName::Cog, false)
207 };
208
209 tool_uses.push(ToolUse {
210 id: tool_use.id.clone(),
211 name: tool_use.name.clone().into(),
212 ui_text: self.tool_ui_label(
213 &tool_use.name,
214 &tool_use.input,
215 tool_use.is_input_complete,
216 cx,
217 ),
218 input: tool_use.input.clone(),
219 status,
220 icon,
221 needs_confirmation,
222 })
223 }
224
225 tool_uses
226 }
227
228 pub fn tool_ui_label(
229 &self,
230 tool_name: &str,
231 input: &serde_json::Value,
232 is_input_complete: bool,
233 cx: &App,
234 ) -> SharedString {
235 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
236 if is_input_complete {
237 tool.ui_text(input).into()
238 } else {
239 tool.still_streaming_ui_text(input).into()
240 }
241 } else {
242 format!("Unknown tool {tool_name:?}").into()
243 }
244 }
245
246 pub fn tool_results_for_message(
247 &self,
248 assistant_message_id: MessageId,
249 ) -> Vec<&LanguageModelToolResult> {
250 let Some(tool_uses) = self
251 .tool_uses_by_assistant_message
252 .get(&assistant_message_id)
253 else {
254 return Vec::new();
255 };
256
257 tool_uses
258 .iter()
259 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
260 .collect()
261 }
262
263 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
264 self.tool_uses_by_assistant_message
265 .get(&assistant_message_id)
266 .map_or(false, |results| !results.is_empty())
267 }
268
269 pub fn tool_result(
270 &self,
271 tool_use_id: &LanguageModelToolUseId,
272 ) -> Option<&LanguageModelToolResult> {
273 self.tool_results.get(tool_use_id)
274 }
275
276 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
277 self.tool_result_cards.get(tool_use_id)
278 }
279
280 pub fn insert_tool_result_card(
281 &mut self,
282 tool_use_id: LanguageModelToolUseId,
283 card: AnyToolCard,
284 ) {
285 self.tool_result_cards.insert(tool_use_id, card);
286 }
287
288 pub fn request_tool_use(
289 &mut self,
290 assistant_message_id: MessageId,
291 tool_use: LanguageModelToolUse,
292 metadata: ToolUseMetadata,
293 cx: &App,
294 ) -> Arc<str> {
295 let tool_uses = self
296 .tool_uses_by_assistant_message
297 .entry(assistant_message_id)
298 .or_default();
299
300 let mut existing_tool_use_found = false;
301
302 for existing_tool_use in tool_uses.iter_mut() {
303 if existing_tool_use.id == tool_use.id {
304 *existing_tool_use = tool_use.clone();
305 existing_tool_use_found = true;
306 }
307 }
308
309 if !existing_tool_use_found {
310 tool_uses.push(tool_use.clone());
311 }
312
313 let status = if tool_use.is_input_complete {
314 self.tool_use_metadata_by_id
315 .insert(tool_use.id.clone(), metadata);
316
317 PendingToolUseStatus::Idle
318 } else {
319 PendingToolUseStatus::InputStillStreaming
320 };
321
322 let ui_text: Arc<str> = self
323 .tool_ui_label(
324 &tool_use.name,
325 &tool_use.input,
326 tool_use.is_input_complete,
327 cx,
328 )
329 .into();
330
331 self.pending_tool_uses_by_id.insert(
332 tool_use.id.clone(),
333 PendingToolUse {
334 assistant_message_id,
335 id: tool_use.id,
336 name: tool_use.name.clone(),
337 ui_text: ui_text.clone(),
338 input: tool_use.input,
339 status,
340 },
341 );
342
343 ui_text
344 }
345
346 pub fn run_pending_tool(
347 &mut self,
348 tool_use_id: LanguageModelToolUseId,
349 ui_text: SharedString,
350 task: Task<()>,
351 ) {
352 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
353 tool_use.ui_text = ui_text.into();
354 tool_use.status = PendingToolUseStatus::Running {
355 _task: task.shared(),
356 };
357 }
358 }
359
360 pub fn confirm_tool_use(
361 &mut self,
362 tool_use_id: LanguageModelToolUseId,
363 ui_text: impl Into<Arc<str>>,
364 input: serde_json::Value,
365 request: Arc<LanguageModelRequest>,
366 tool: Arc<dyn Tool>,
367 ) {
368 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
369 let ui_text = ui_text.into();
370 tool_use.ui_text = ui_text.clone();
371 let confirmation = Confirmation {
372 tool_use_id,
373 input,
374 request,
375 tool,
376 ui_text,
377 };
378 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
379 }
380 }
381
382 pub fn insert_tool_output(
383 &mut self,
384 tool_use_id: LanguageModelToolUseId,
385 tool_name: Arc<str>,
386 output: Result<ToolResultOutput>,
387 configured_model: Option<&ConfiguredModel>,
388 ) -> Option<PendingToolUse> {
389 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
390
391 telemetry::event!(
392 "Agent Tool Finished",
393 model = metadata
394 .as_ref()
395 .map(|metadata| metadata.model.telemetry_id()),
396 model_provider = metadata
397 .as_ref()
398 .map(|metadata| metadata.model.provider_id().to_string()),
399 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
400 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
401 tool_name,
402 success = output.is_ok()
403 );
404
405 match output {
406 Ok(output) => {
407 let tool_result = output.content;
408 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
409
410 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
411
412 // Protect from overly large output
413 let tool_output_limit = configured_model
414 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
415 .unwrap_or(usize::MAX);
416
417 let content = match tool_result {
418 ToolResultContent::Text(text) => {
419 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
420
421 LanguageModelToolResultContent::Text(
422 format!(
423 "Tool result too long. The first {} bytes:\n\n{}",
424 truncated.len(),
425 truncated
426 )
427 .into(),
428 )
429 }
430 ToolResultContent::Image(language_model_image) => {
431 if language_model_image.estimate_tokens() < tool_output_limit {
432 LanguageModelToolResultContent::Image(language_model_image)
433 } else {
434 self.tool_results.insert(
435 tool_use_id.clone(),
436 LanguageModelToolResult {
437 tool_use_id: tool_use_id.clone(),
438 tool_name,
439 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
440 is_error: true,
441 output: None,
442 },
443 );
444
445 return old_use;
446 }
447 }
448 };
449
450 self.tool_results.insert(
451 tool_use_id.clone(),
452 LanguageModelToolResult {
453 tool_use_id: tool_use_id.clone(),
454 tool_name,
455 content,
456 is_error: false,
457 output: output.output,
458 },
459 );
460
461 old_use
462 }
463 Err(err) => {
464 self.tool_results.insert(
465 tool_use_id.clone(),
466 LanguageModelToolResult {
467 tool_use_id: tool_use_id.clone(),
468 tool_name,
469 content: LanguageModelToolResultContent::Text(err.to_string().into()),
470 is_error: true,
471 output: None,
472 },
473 );
474
475 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
476 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
477 }
478
479 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
480 }
481 }
482 }
483
484 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
485 self.tool_uses_by_assistant_message
486 .contains_key(&assistant_message_id)
487 }
488
489 pub fn tool_results(
490 &self,
491 assistant_message_id: MessageId,
492 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
493 self.tool_uses_by_assistant_message
494 .get(&assistant_message_id)
495 .into_iter()
496 .flatten()
497 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
498 }
499}
500
501#[derive(Debug, Clone)]
502pub struct PendingToolUse {
503 pub id: LanguageModelToolUseId,
504 /// The ID of the Assistant message in which the tool use was requested.
505 #[allow(unused)]
506 pub assistant_message_id: MessageId,
507 pub name: Arc<str>,
508 pub ui_text: Arc<str>,
509 pub input: serde_json::Value,
510 pub status: PendingToolUseStatus,
511}
512
513#[derive(Debug, Clone)]
514pub struct Confirmation {
515 pub tool_use_id: LanguageModelToolUseId,
516 pub input: serde_json::Value,
517 pub ui_text: Arc<str>,
518 pub request: Arc<LanguageModelRequest>,
519 pub tool: Arc<dyn Tool>,
520}
521
522#[derive(Debug, Clone)]
523pub enum PendingToolUseStatus {
524 InputStillStreaming,
525 Idle,
526 NeedsConfirmation(Arc<Confirmation>),
527 Running { _task: Shared<Task<()>> },
528 Error(#[allow(unused)] Arc<str>),
529}
530
531impl PendingToolUseStatus {
532 pub fn is_idle(&self) -> bool {
533 matches!(self, PendingToolUseStatus::Idle)
534 }
535
536 pub fn is_error(&self) -> bool {
537 matches!(self, PendingToolUseStatus::Error(_))
538 }
539
540 pub fn needs_confirmation(&self) -> bool {
541 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
542 }
543}
544
545#[derive(Clone)]
546pub struct ToolUseMetadata {
547 pub model: Arc<dyn LanguageModel>,
548 pub thread_id: ThreadId,
549 pub prompt_id: PromptId,
550}