1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
19#[derive(Debug)]
20pub struct ToolUse {
21 pub id: LanguageModelToolUseId,
22 pub name: SharedString,
23 pub ui_text: SharedString,
24 pub status: ToolUseStatus,
25 pub input: serde_json::Value,
26 pub icon: ui::IconName,
27 pub needs_confirmation: bool,
28}
29
30pub struct ToolUseState {
31 tools: Entity<ToolWorkingSet>,
32 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
33 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
34 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
35 tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
36 tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
37}
38
39impl ToolUseState {
40 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
41 Self {
42 tools,
43 tool_uses_by_assistant_message: HashMap::default(),
44 tool_results: HashMap::default(),
45 pending_tool_uses_by_id: HashMap::default(),
46 tool_result_cards: HashMap::default(),
47 tool_use_metadata_by_id: HashMap::default(),
48 }
49 }
50
51 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
52 ///
53 /// Accepts a function to filter the tools that should be used to populate the state.
54 pub fn from_serialized_messages(
55 tools: Entity<ToolWorkingSet>,
56 messages: &[SerializedMessage],
57 ) -> Self {
58 let mut this = Self::new(tools);
59 let mut tool_names_by_id = HashMap::default();
60
61 for message in messages {
62 match message.role {
63 Role::Assistant => {
64 if !message.tool_uses.is_empty() {
65 let tool_uses = message
66 .tool_uses
67 .iter()
68 .map(|tool_use| LanguageModelToolUse {
69 id: tool_use.id.clone(),
70 name: tool_use.name.clone().into(),
71 raw_input: tool_use.input.to_string(),
72 input: tool_use.input.clone(),
73 is_input_complete: true,
74 })
75 .collect::<Vec<_>>();
76
77 tool_names_by_id.extend(
78 tool_uses
79 .iter()
80 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
81 );
82
83 this.tool_uses_by_assistant_message
84 .insert(message.id, tool_uses);
85
86 for tool_result in &message.tool_results {
87 let tool_use_id = tool_result.tool_use_id.clone();
88 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
89 log::warn!("no tool name found for tool use: {tool_use_id:?}");
90 continue;
91 };
92
93 this.tool_results.insert(
94 tool_use_id.clone(),
95 LanguageModelToolResult {
96 tool_use_id,
97 tool_name: tool_use.clone(),
98 is_error: tool_result.is_error,
99 content: tool_result.content.clone(),
100 },
101 );
102 }
103 }
104 }
105 Role::System | Role::User => {}
106 }
107 }
108
109 this
110 }
111
112 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
113 let mut pending_tools = Vec::new();
114 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
115 self.tool_results.insert(
116 tool_use_id.clone(),
117 LanguageModelToolResult {
118 tool_use_id,
119 tool_name: tool_use.name.clone(),
120 content: "Tool canceled by user".into(),
121 is_error: true,
122 },
123 );
124 pending_tools.push(tool_use.clone());
125 }
126 pending_tools
127 }
128
129 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
130 self.pending_tool_uses_by_id.values().collect()
131 }
132
133 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
134 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
135 return Vec::new();
136 };
137
138 let mut tool_uses = Vec::new();
139
140 for tool_use in tool_uses_for_message.iter() {
141 let tool_result = self.tool_results.get(&tool_use.id);
142
143 let status = (|| {
144 if let Some(tool_result) = tool_result {
145 return if tool_result.is_error {
146 ToolUseStatus::Error(tool_result.content.clone().into())
147 } else {
148 ToolUseStatus::Finished(tool_result.content.clone().into())
149 };
150 }
151
152 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
153 match pending_tool_use.status {
154 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
155 PendingToolUseStatus::NeedsConfirmation { .. } => {
156 ToolUseStatus::NeedsConfirmation
157 }
158 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
159 PendingToolUseStatus::Error(ref err) => {
160 ToolUseStatus::Error(err.clone().into())
161 }
162 PendingToolUseStatus::InputStillStreaming => {
163 ToolUseStatus::InputStillStreaming
164 }
165 }
166 } else {
167 ToolUseStatus::Pending
168 }
169 })();
170
171 let (icon, needs_confirmation) =
172 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
173 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
174 } else {
175 (IconName::Cog, false)
176 };
177
178 tool_uses.push(ToolUse {
179 id: tool_use.id.clone(),
180 name: tool_use.name.clone().into(),
181 ui_text: self.tool_ui_label(
182 &tool_use.name,
183 &tool_use.input,
184 tool_use.is_input_complete,
185 cx,
186 ),
187 input: tool_use.input.clone(),
188 status,
189 icon,
190 needs_confirmation,
191 })
192 }
193
194 tool_uses
195 }
196
197 pub fn tool_ui_label(
198 &self,
199 tool_name: &str,
200 input: &serde_json::Value,
201 is_input_complete: bool,
202 cx: &App,
203 ) -> SharedString {
204 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
205 if is_input_complete {
206 tool.ui_text(input).into()
207 } else {
208 tool.still_streaming_ui_text(input).into()
209 }
210 } else {
211 format!("Unknown tool {tool_name:?}").into()
212 }
213 }
214
215 pub fn tool_results_for_message(
216 &self,
217 assistant_message_id: MessageId,
218 ) -> Vec<&LanguageModelToolResult> {
219 let Some(tool_uses) = self
220 .tool_uses_by_assistant_message
221 .get(&assistant_message_id)
222 else {
223 return Vec::new();
224 };
225
226 tool_uses
227 .iter()
228 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
229 .collect()
230 }
231
232 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
233 self.tool_uses_by_assistant_message
234 .get(&assistant_message_id)
235 .map_or(false, |results| !results.is_empty())
236 }
237
238 pub fn tool_result(
239 &self,
240 tool_use_id: &LanguageModelToolUseId,
241 ) -> Option<&LanguageModelToolResult> {
242 self.tool_results.get(tool_use_id)
243 }
244
245 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
246 self.tool_result_cards.get(tool_use_id)
247 }
248
249 pub fn insert_tool_result_card(
250 &mut self,
251 tool_use_id: LanguageModelToolUseId,
252 card: AnyToolCard,
253 ) {
254 self.tool_result_cards.insert(tool_use_id, card);
255 }
256
257 pub fn request_tool_use(
258 &mut self,
259 assistant_message_id: MessageId,
260 tool_use: LanguageModelToolUse,
261 metadata: ToolUseMetadata,
262 cx: &App,
263 ) -> Arc<str> {
264 let tool_uses = self
265 .tool_uses_by_assistant_message
266 .entry(assistant_message_id)
267 .or_default();
268
269 let mut existing_tool_use_found = false;
270
271 for existing_tool_use in tool_uses.iter_mut() {
272 if existing_tool_use.id == tool_use.id {
273 *existing_tool_use = tool_use.clone();
274 existing_tool_use_found = true;
275 }
276 }
277
278 if !existing_tool_use_found {
279 tool_uses.push(tool_use.clone());
280 }
281
282 let status = if tool_use.is_input_complete {
283 self.tool_use_metadata_by_id
284 .insert(tool_use.id.clone(), metadata);
285
286 PendingToolUseStatus::Idle
287 } else {
288 PendingToolUseStatus::InputStillStreaming
289 };
290
291 let ui_text: Arc<str> = self
292 .tool_ui_label(
293 &tool_use.name,
294 &tool_use.input,
295 tool_use.is_input_complete,
296 cx,
297 )
298 .into();
299
300 self.pending_tool_uses_by_id.insert(
301 tool_use.id.clone(),
302 PendingToolUse {
303 assistant_message_id,
304 id: tool_use.id,
305 name: tool_use.name.clone(),
306 ui_text: ui_text.clone(),
307 input: tool_use.input,
308 status,
309 },
310 );
311
312 ui_text
313 }
314
315 pub fn run_pending_tool(
316 &mut self,
317 tool_use_id: LanguageModelToolUseId,
318 ui_text: SharedString,
319 task: Task<()>,
320 ) {
321 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
322 tool_use.ui_text = ui_text.into();
323 tool_use.status = PendingToolUseStatus::Running {
324 _task: task.shared(),
325 };
326 }
327 }
328
329 pub fn confirm_tool_use(
330 &mut self,
331 tool_use_id: LanguageModelToolUseId,
332 ui_text: impl Into<Arc<str>>,
333 input: serde_json::Value,
334 messages: Arc<Vec<LanguageModelRequestMessage>>,
335 tool: Arc<dyn Tool>,
336 ) {
337 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
338 let ui_text = ui_text.into();
339 tool_use.ui_text = ui_text.clone();
340 let confirmation = Confirmation {
341 tool_use_id,
342 input,
343 messages,
344 tool,
345 ui_text,
346 };
347 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
348 }
349 }
350
351 pub fn insert_tool_output(
352 &mut self,
353 tool_use_id: LanguageModelToolUseId,
354 tool_name: Arc<str>,
355 output: Result<String>,
356 configured_model: Option<&ConfiguredModel>,
357 ) -> Option<PendingToolUse> {
358 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
359
360 telemetry::event!(
361 "Agent Tool Finished",
362 model = metadata
363 .as_ref()
364 .map(|metadata| metadata.model.telemetry_id()),
365 model_provider = metadata
366 .as_ref()
367 .map(|metadata| metadata.model.provider_id().to_string()),
368 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
369 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
370 tool_name,
371 success = output.is_ok()
372 );
373
374 match output {
375 Ok(tool_result) => {
376 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
377
378 // Protect from clearly large output
379 let tool_output_limit = configured_model
380 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
381 .unwrap_or(usize::MAX);
382
383 let tool_result = if tool_result.len() <= tool_output_limit {
384 tool_result
385 } else {
386 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
387
388 format!(
389 "Tool result too long. The first {} bytes:\n\n{}",
390 truncated.len(),
391 truncated
392 )
393 };
394
395 self.tool_results.insert(
396 tool_use_id.clone(),
397 LanguageModelToolResult {
398 tool_use_id: tool_use_id.clone(),
399 tool_name,
400 content: tool_result.into(),
401 is_error: false,
402 },
403 );
404 self.pending_tool_uses_by_id.remove(&tool_use_id)
405 }
406 Err(err) => {
407 self.tool_results.insert(
408 tool_use_id.clone(),
409 LanguageModelToolResult {
410 tool_use_id: tool_use_id.clone(),
411 tool_name,
412 content: err.to_string().into(),
413 is_error: true,
414 },
415 );
416
417 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
418 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
419 }
420
421 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
422 }
423 }
424 }
425
426 pub fn attach_tool_uses(
427 &self,
428 message_id: MessageId,
429 request_message: &mut LanguageModelRequestMessage,
430 ) {
431 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
432 for tool_use in tool_uses {
433 if self.tool_results.contains_key(&tool_use.id) {
434 // Do not send tool uses until they are completed
435 request_message
436 .content
437 .push(MessageContent::ToolUse(tool_use.clone()));
438 } else {
439 log::debug!(
440 "skipped tool use {:?} because it is still pending",
441 tool_use
442 );
443 }
444 }
445 }
446 }
447
448 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
449 self.tool_uses_by_assistant_message
450 .contains_key(&assistant_message_id)
451 }
452
453 pub fn tool_results_message(
454 &self,
455 assistant_message_id: MessageId,
456 ) -> Option<LanguageModelRequestMessage> {
457 let tool_uses = self
458 .tool_uses_by_assistant_message
459 .get(&assistant_message_id)?;
460
461 if tool_uses.is_empty() {
462 return None;
463 }
464
465 let mut request_message = LanguageModelRequestMessage {
466 role: Role::User,
467 content: vec![],
468 cache: false,
469 };
470
471 for tool_use in tool_uses {
472 if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
473 request_message
474 .content
475 .push(MessageContent::ToolResult(LanguageModelToolResult {
476 tool_use_id: tool_use.id.clone(),
477 tool_name: tool_result.tool_name.clone(),
478 is_error: tool_result.is_error,
479 content: if tool_result.content.is_empty() {
480 // Surprisingly, the API fails if we return an empty string here.
481 // It thinks we are sending a tool use without a tool result.
482 "<Tool returned an empty string>".into()
483 } else {
484 tool_result.content.clone()
485 },
486 }));
487 }
488 }
489
490 Some(request_message)
491 }
492}
493
494#[derive(Debug, Clone)]
495pub struct PendingToolUse {
496 pub id: LanguageModelToolUseId,
497 /// The ID of the Assistant message in which the tool use was requested.
498 #[allow(unused)]
499 pub assistant_message_id: MessageId,
500 pub name: Arc<str>,
501 pub ui_text: Arc<str>,
502 pub input: serde_json::Value,
503 pub status: PendingToolUseStatus,
504}
505
506#[derive(Debug, Clone)]
507pub struct Confirmation {
508 pub tool_use_id: LanguageModelToolUseId,
509 pub input: serde_json::Value,
510 pub ui_text: Arc<str>,
511 pub messages: Arc<Vec<LanguageModelRequestMessage>>,
512 pub tool: Arc<dyn Tool>,
513}
514
515#[derive(Debug, Clone)]
516pub enum PendingToolUseStatus {
517 InputStillStreaming,
518 Idle,
519 NeedsConfirmation(Arc<Confirmation>),
520 Running { _task: Shared<Task<()>> },
521 Error(#[allow(unused)] Arc<str>),
522}
523
524impl PendingToolUseStatus {
525 pub fn is_idle(&self) -> bool {
526 matches!(self, PendingToolUseStatus::Idle)
527 }
528
529 pub fn is_error(&self) -> bool {
530 matches!(self, PendingToolUseStatus::Error(_))
531 }
532
533 pub fn needs_confirmation(&self) -> bool {
534 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
535 }
536}
537
538#[derive(Clone)]
539pub struct ToolUseMetadata {
540 pub model: Arc<dyn LanguageModel>,
541 pub thread_id: ThreadId,
542 pub prompt_id: PromptId,
543}