1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
19#[derive(Debug)]
20pub struct ToolUse {
21 pub id: LanguageModelToolUseId,
22 pub name: SharedString,
23 pub ui_text: SharedString,
24 pub status: ToolUseStatus,
25 pub input: serde_json::Value,
26 pub icon: ui::IconName,
27 pub needs_confirmation: bool,
28}
29
30pub struct ToolUseState {
31 tools: Entity<ToolWorkingSet>,
32 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
33 tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
34 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
35 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
36 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
37 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
38}
39
40impl ToolUseState {
41 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
42 Self {
43 tools,
44 tool_uses_by_assistant_message: HashMap::default(),
45 tool_uses_by_user_message: HashMap::default(),
46 tool_results: HashMap::default(),
47 pending_tool_uses_by_id: HashMap::default(),
48 tool_result_cards: HashMap::default(),
49 tool_use_metadata_by_id: HashMap::default(),
50 }
51 }
52
53 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
54 ///
55 /// Accepts a function to filter the tools that should be used to populate the state.
56 pub fn from_serialized_messages(
57 tools: Entity<ToolWorkingSet>,
58 messages: &[SerializedMessage],
59 mut filter_by_tool_name: impl FnMut(&str) -> bool,
60 ) -> Self {
61 let mut this = Self::new(tools);
62 let mut tool_names_by_id = HashMap::default();
63
64 for message in messages {
65 match message.role {
66 Role::Assistant => {
67 if !message.tool_uses.is_empty() {
68 let tool_uses = message
69 .tool_uses
70 .iter()
71 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
72 .map(|tool_use| LanguageModelToolUse {
73 id: tool_use.id.clone(),
74 name: tool_use.name.clone().into(),
75 input: tool_use.input.clone(),
76 })
77 .collect::<Vec<_>>();
78
79 tool_names_by_id.extend(
80 tool_uses
81 .iter()
82 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
83 );
84
85 this.tool_uses_by_assistant_message
86 .insert(message.id, tool_uses);
87 }
88 }
89 Role::User => {
90 if !message.tool_results.is_empty() {
91 let tool_uses_by_user_message = this
92 .tool_uses_by_user_message
93 .entry(message.id)
94 .or_default();
95
96 for tool_result in &message.tool_results {
97 let tool_use_id = tool_result.tool_use_id.clone();
98 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
99 log::warn!("no tool name found for tool use: {tool_use_id:?}");
100 continue;
101 };
102
103 if !(filter_by_tool_name)(tool_use.as_ref()) {
104 continue;
105 }
106
107 tool_uses_by_user_message.push(tool_use_id.clone());
108 this.tool_results.insert(
109 tool_use_id.clone(),
110 LanguageModelToolResult {
111 tool_use_id,
112 tool_name: tool_use.clone(),
113 is_error: tool_result.is_error,
114 content: tool_result.content.clone(),
115 },
116 );
117 }
118 }
119 }
120 Role::System => {}
121 }
122 }
123
124 this
125 }
126
127 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
128 let mut pending_tools = Vec::new();
129 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
130 self.tool_results.insert(
131 tool_use_id.clone(),
132 LanguageModelToolResult {
133 tool_use_id,
134 tool_name: tool_use.name.clone(),
135 content: "Tool canceled by user".into(),
136 is_error: true,
137 },
138 );
139 pending_tools.push(tool_use.clone());
140 }
141 pending_tools
142 }
143
144 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
145 self.pending_tool_uses_by_id.values().collect()
146 }
147
148 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
149 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
150 return Vec::new();
151 };
152
153 let mut tool_uses = Vec::new();
154
155 for tool_use in tool_uses_for_message.iter() {
156 let tool_result = self.tool_results.get(&tool_use.id);
157
158 let status = (|| {
159 if let Some(tool_result) = tool_result {
160 return if tool_result.is_error {
161 ToolUseStatus::Error(tool_result.content.clone().into())
162 } else {
163 ToolUseStatus::Finished(tool_result.content.clone().into())
164 };
165 }
166
167 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
168 match pending_tool_use.status {
169 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
170 PendingToolUseStatus::NeedsConfirmation { .. } => {
171 ToolUseStatus::NeedsConfirmation
172 }
173 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
174 PendingToolUseStatus::Error(ref err) => {
175 ToolUseStatus::Error(err.clone().into())
176 }
177 }
178 } else {
179 ToolUseStatus::Pending
180 }
181 })();
182
183 let (icon, needs_confirmation) =
184 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
185 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
186 } else {
187 (IconName::Cog, false)
188 };
189
190 tool_uses.push(ToolUse {
191 id: tool_use.id.clone(),
192 name: tool_use.name.clone().into(),
193 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
194 input: tool_use.input.clone(),
195 status,
196 icon,
197 needs_confirmation,
198 })
199 }
200
201 tool_uses
202 }
203
204 pub fn tool_ui_label(
205 &self,
206 tool_name: &str,
207 input: &serde_json::Value,
208 cx: &App,
209 ) -> SharedString {
210 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
211 tool.ui_text(input).into()
212 } else {
213 format!("Unknown tool {tool_name:?}").into()
214 }
215 }
216
217 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
218 let empty = Vec::new();
219
220 self.tool_uses_by_user_message
221 .get(&message_id)
222 .unwrap_or(&empty)
223 .iter()
224 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
225 .collect()
226 }
227
228 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
229 self.tool_uses_by_user_message
230 .get(&message_id)
231 .map_or(false, |results| !results.is_empty())
232 }
233
234 pub fn tool_result(
235 &self,
236 tool_use_id: &LanguageModelToolUseId,
237 ) -> Option<&LanguageModelToolResult> {
238 self.tool_results.get(tool_use_id)
239 }
240
241 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
242 self.tool_result_cards.get(tool_use_id)
243 }
244
245 pub fn insert_tool_result_card(
246 &mut self,
247 tool_use_id: LanguageModelToolUseId,
248 card: AnyToolCard,
249 ) {
250 self.tool_result_cards.insert(tool_use_id, card);
251 }
252
253 pub fn request_tool_use(
254 &mut self,
255 assistant_message_id: MessageId,
256 tool_use: LanguageModelToolUse,
257 metadata: ToolUseMetadata,
258 cx: &App,
259 ) {
260 self.tool_uses_by_assistant_message
261 .entry(assistant_message_id)
262 .or_default()
263 .push(tool_use.clone());
264
265 self.tool_use_metadata_by_id
266 .insert(tool_use.id.clone(), metadata);
267
268 // The tool use is being requested by the Assistant, so we want to
269 // attach the tool results to the next user message.
270 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
271 self.tool_uses_by_user_message
272 .entry(next_user_message_id)
273 .or_default()
274 .push(tool_use.id.clone());
275
276 self.pending_tool_uses_by_id.insert(
277 tool_use.id.clone(),
278 PendingToolUse {
279 assistant_message_id,
280 id: tool_use.id,
281 name: tool_use.name.clone(),
282 ui_text: self
283 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
284 .into(),
285 input: tool_use.input,
286 status: PendingToolUseStatus::Idle,
287 },
288 );
289 }
290
291 pub fn run_pending_tool(
292 &mut self,
293 tool_use_id: LanguageModelToolUseId,
294 ui_text: SharedString,
295 task: Task<()>,
296 ) {
297 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
298 tool_use.ui_text = ui_text.into();
299 tool_use.status = PendingToolUseStatus::Running {
300 _task: task.shared(),
301 };
302 }
303 }
304
305 pub fn confirm_tool_use(
306 &mut self,
307 tool_use_id: LanguageModelToolUseId,
308 ui_text: impl Into<Arc<str>>,
309 input: serde_json::Value,
310 messages: Arc<Vec<LanguageModelRequestMessage>>,
311 tool: Arc<dyn Tool>,
312 ) {
313 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
314 let ui_text = ui_text.into();
315 tool_use.ui_text = ui_text.clone();
316 let confirmation = Confirmation {
317 tool_use_id,
318 input,
319 messages,
320 tool,
321 ui_text,
322 };
323 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
324 }
325 }
326
327 pub fn insert_tool_output(
328 &mut self,
329 tool_use_id: LanguageModelToolUseId,
330 tool_name: Arc<str>,
331 output: Result<String>,
332 cx: &App,
333 ) -> Option<PendingToolUse> {
334 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
335
336 telemetry::event!(
337 "Agent Tool Finished",
338 model = metadata
339 .as_ref()
340 .map(|metadata| metadata.model.telemetry_id()),
341 model_provider = metadata
342 .as_ref()
343 .map(|metadata| metadata.model.provider_id().to_string()),
344 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
345 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
346 tool_name,
347 success = output.is_ok()
348 );
349
350 match output {
351 Ok(tool_result) => {
352 let model_registry = LanguageModelRegistry::read_global(cx);
353
354 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
355
356 // Protect from clearly large output
357 let tool_output_limit = model_registry
358 .default_model()
359 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
360 .unwrap_or(usize::MAX);
361
362 let tool_result = if tool_result.len() <= tool_output_limit {
363 tool_result
364 } else {
365 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
366
367 format!(
368 "Tool result too long. The first {} bytes:\n\n{}",
369 truncated.len(),
370 truncated
371 )
372 };
373
374 self.tool_results.insert(
375 tool_use_id.clone(),
376 LanguageModelToolResult {
377 tool_use_id: tool_use_id.clone(),
378 tool_name,
379 content: tool_result.into(),
380 is_error: false,
381 },
382 );
383 self.pending_tool_uses_by_id.remove(&tool_use_id)
384 }
385 Err(err) => {
386 self.tool_results.insert(
387 tool_use_id.clone(),
388 LanguageModelToolResult {
389 tool_use_id: tool_use_id.clone(),
390 tool_name,
391 content: err.to_string().into(),
392 is_error: true,
393 },
394 );
395
396 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
397 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
398 }
399
400 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
401 }
402 }
403 }
404
405 pub fn attach_tool_uses(
406 &self,
407 message_id: MessageId,
408 request_message: &mut LanguageModelRequestMessage,
409 ) {
410 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
411 for tool_use in tool_uses {
412 if self.tool_results.contains_key(&tool_use.id) {
413 // Do not send tool uses until they are completed
414 request_message
415 .content
416 .push(MessageContent::ToolUse(tool_use.clone()));
417 } else {
418 log::debug!(
419 "skipped tool use {:?} because it is still pending",
420 tool_use
421 );
422 }
423 }
424 }
425 }
426
427 pub fn attach_tool_results(
428 &self,
429 message_id: MessageId,
430 request_message: &mut LanguageModelRequestMessage,
431 ) {
432 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
433 for tool_use_id in tool_uses {
434 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
435 request_message.content.push(MessageContent::ToolResult(
436 LanguageModelToolResult {
437 tool_use_id: tool_use_id.clone(),
438 tool_name: tool_result.tool_name.clone(),
439 is_error: tool_result.is_error,
440 content: if tool_result.content.is_empty() {
441 // Surprisingly, the API fails if we return an empty string here.
442 // It thinks we are sending a tool use without a tool result.
443 "<Tool returned an empty string>".into()
444 } else {
445 tool_result.content.clone()
446 },
447 },
448 ));
449 }
450 }
451 }
452 }
453}
454
455#[derive(Debug, Clone)]
456pub struct PendingToolUse {
457 pub id: LanguageModelToolUseId,
458 /// The ID of the Assistant message in which the tool use was requested.
459 #[allow(unused)]
460 pub assistant_message_id: MessageId,
461 pub name: Arc<str>,
462 pub ui_text: Arc<str>,
463 pub input: serde_json::Value,
464 pub status: PendingToolUseStatus,
465}
466
467#[derive(Debug, Clone)]
468pub struct Confirmation {
469 pub tool_use_id: LanguageModelToolUseId,
470 pub input: serde_json::Value,
471 pub ui_text: Arc<str>,
472 pub messages: Arc<Vec<LanguageModelRequestMessage>>,
473 pub tool: Arc<dyn Tool>,
474}
475
476#[derive(Debug, Clone)]
477pub enum PendingToolUseStatus {
478 Idle,
479 NeedsConfirmation(Arc<Confirmation>),
480 Running { _task: Shared<Task<()>> },
481 Error(#[allow(unused)] Arc<str>),
482}
483
484impl PendingToolUseStatus {
485 pub fn is_idle(&self) -> bool {
486 matches!(self, PendingToolUseStatus::Idle)
487 }
488
489 pub fn is_error(&self) -> bool {
490 matches!(self, PendingToolUseStatus::Error(_))
491 }
492
493 pub fn needs_confirmation(&self) -> bool {
494 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
495 }
496}
497
498#[derive(Clone)]
499pub struct ToolUseMetadata {
500 pub model: Arc<dyn LanguageModel>,
501 pub thread_id: ThreadId,
502 pub prompt_id: PromptId,
503}