1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
19#[derive(Debug)]
20pub struct ToolUse {
21 pub id: LanguageModelToolUseId,
22 pub name: SharedString,
23 pub ui_text: SharedString,
24 pub status: ToolUseStatus,
25 pub input: serde_json::Value,
26 pub icon: ui::IconName,
27 pub needs_confirmation: bool,
28}
29
/// Tracks the tool uses requested by assistant messages, together with their
/// results, in-flight state, result cards and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look up tools by name (icons, labels, confirmation).
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Final results (successful, failed or canceled), keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses still tracked as in flight. Entries are removed on success or
    /// cancellation; on failure they are kept with an `Error` status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards attached to tool uses via `insert_tool_result_card`.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Telemetry metadata, recorded once a tool use's input is complete and
    /// consumed when its output is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
38
39impl ToolUseState {
40 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
41 Self {
42 tools,
43 tool_uses_by_assistant_message: HashMap::default(),
44 tool_results: HashMap::default(),
45 pending_tool_uses_by_id: HashMap::default(),
46 tool_result_cards: HashMap::default(),
47 tool_use_metadata_by_id: HashMap::default(),
48 }
49 }
50
51 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
52 ///
53 /// Accepts a function to filter the tools that should be used to populate the state.
54 pub fn from_serialized_messages(
55 tools: Entity<ToolWorkingSet>,
56 messages: &[SerializedMessage],
57 ) -> Self {
58 let mut this = Self::new(tools);
59 let mut tool_names_by_id = HashMap::default();
60
61 for message in messages {
62 match message.role {
63 Role::Assistant => {
64 if !message.tool_uses.is_empty() {
65 let tool_uses = message
66 .tool_uses
67 .iter()
68 .map(|tool_use| LanguageModelToolUse {
69 id: tool_use.id.clone(),
70 name: tool_use.name.clone().into(),
71 input: tool_use.input.clone(),
72 is_input_complete: true,
73 })
74 .collect::<Vec<_>>();
75
76 tool_names_by_id.extend(
77 tool_uses
78 .iter()
79 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
80 );
81
82 this.tool_uses_by_assistant_message
83 .insert(message.id, tool_uses);
84
85 for tool_result in &message.tool_results {
86 let tool_use_id = tool_result.tool_use_id.clone();
87 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
88 log::warn!("no tool name found for tool use: {tool_use_id:?}");
89 continue;
90 };
91
92 this.tool_results.insert(
93 tool_use_id.clone(),
94 LanguageModelToolResult {
95 tool_use_id,
96 tool_name: tool_use.clone(),
97 is_error: tool_result.is_error,
98 content: tool_result.content.clone(),
99 },
100 );
101 }
102 }
103 }
104 Role::System | Role::User => {}
105 }
106 }
107
108 this
109 }
110
111 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
112 let mut pending_tools = Vec::new();
113 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
114 self.tool_results.insert(
115 tool_use_id.clone(),
116 LanguageModelToolResult {
117 tool_use_id,
118 tool_name: tool_use.name.clone(),
119 content: "Tool canceled by user".into(),
120 is_error: true,
121 },
122 );
123 pending_tools.push(tool_use.clone());
124 }
125 pending_tools
126 }
127
128 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
129 self.pending_tool_uses_by_id.values().collect()
130 }
131
132 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
133 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
134 return Vec::new();
135 };
136
137 let mut tool_uses = Vec::new();
138
139 for tool_use in tool_uses_for_message.iter() {
140 let tool_result = self.tool_results.get(&tool_use.id);
141
142 let status = (|| {
143 if let Some(tool_result) = tool_result {
144 return if tool_result.is_error {
145 ToolUseStatus::Error(tool_result.content.clone().into())
146 } else {
147 ToolUseStatus::Finished(tool_result.content.clone().into())
148 };
149 }
150
151 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
152 match pending_tool_use.status {
153 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
154 PendingToolUseStatus::NeedsConfirmation { .. } => {
155 ToolUseStatus::NeedsConfirmation
156 }
157 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
158 PendingToolUseStatus::Error(ref err) => {
159 ToolUseStatus::Error(err.clone().into())
160 }
161 PendingToolUseStatus::InputStillStreaming => {
162 ToolUseStatus::InputStillStreaming
163 }
164 }
165 } else {
166 ToolUseStatus::Pending
167 }
168 })();
169
170 let (icon, needs_confirmation) =
171 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
172 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
173 } else {
174 (IconName::Cog, false)
175 };
176
177 tool_uses.push(ToolUse {
178 id: tool_use.id.clone(),
179 name: tool_use.name.clone().into(),
180 ui_text: self.tool_ui_label(
181 &tool_use.name,
182 &tool_use.input,
183 tool_use.is_input_complete,
184 cx,
185 ),
186 input: tool_use.input.clone(),
187 status,
188 icon,
189 needs_confirmation,
190 })
191 }
192
193 tool_uses
194 }
195
196 pub fn tool_ui_label(
197 &self,
198 tool_name: &str,
199 input: &serde_json::Value,
200 is_input_complete: bool,
201 cx: &App,
202 ) -> SharedString {
203 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
204 if is_input_complete {
205 tool.ui_text(input).into()
206 } else {
207 tool.still_streaming_ui_text(input).into()
208 }
209 } else {
210 format!("Unknown tool {tool_name:?}").into()
211 }
212 }
213
214 pub fn tool_results_for_message(
215 &self,
216 assistant_message_id: MessageId,
217 ) -> Vec<&LanguageModelToolResult> {
218 let Some(tool_uses) = self
219 .tool_uses_by_assistant_message
220 .get(&assistant_message_id)
221 else {
222 return Vec::new();
223 };
224
225 tool_uses
226 .iter()
227 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
228 .collect()
229 }
230
231 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
232 self.tool_uses_by_assistant_message
233 .get(&assistant_message_id)
234 .map_or(false, |results| !results.is_empty())
235 }
236
237 pub fn tool_result(
238 &self,
239 tool_use_id: &LanguageModelToolUseId,
240 ) -> Option<&LanguageModelToolResult> {
241 self.tool_results.get(tool_use_id)
242 }
243
244 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
245 self.tool_result_cards.get(tool_use_id)
246 }
247
248 pub fn insert_tool_result_card(
249 &mut self,
250 tool_use_id: LanguageModelToolUseId,
251 card: AnyToolCard,
252 ) {
253 self.tool_result_cards.insert(tool_use_id, card);
254 }
255
256 pub fn request_tool_use(
257 &mut self,
258 assistant_message_id: MessageId,
259 tool_use: LanguageModelToolUse,
260 metadata: ToolUseMetadata,
261 cx: &App,
262 ) -> Arc<str> {
263 let tool_uses = self
264 .tool_uses_by_assistant_message
265 .entry(assistant_message_id)
266 .or_default();
267
268 let mut existing_tool_use_found = false;
269
270 for existing_tool_use in tool_uses.iter_mut() {
271 if existing_tool_use.id == tool_use.id {
272 *existing_tool_use = tool_use.clone();
273 existing_tool_use_found = true;
274 }
275 }
276
277 if !existing_tool_use_found {
278 tool_uses.push(tool_use.clone());
279 }
280
281 let status = if tool_use.is_input_complete {
282 self.tool_use_metadata_by_id
283 .insert(tool_use.id.clone(), metadata);
284
285 PendingToolUseStatus::Idle
286 } else {
287 PendingToolUseStatus::InputStillStreaming
288 };
289
290 let ui_text: Arc<str> = self
291 .tool_ui_label(
292 &tool_use.name,
293 &tool_use.input,
294 tool_use.is_input_complete,
295 cx,
296 )
297 .into();
298
299 self.pending_tool_uses_by_id.insert(
300 tool_use.id.clone(),
301 PendingToolUse {
302 assistant_message_id,
303 id: tool_use.id,
304 name: tool_use.name.clone(),
305 ui_text: ui_text.clone(),
306 input: tool_use.input,
307 status,
308 },
309 );
310
311 ui_text
312 }
313
314 pub fn run_pending_tool(
315 &mut self,
316 tool_use_id: LanguageModelToolUseId,
317 ui_text: SharedString,
318 task: Task<()>,
319 ) {
320 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
321 tool_use.ui_text = ui_text.into();
322 tool_use.status = PendingToolUseStatus::Running {
323 _task: task.shared(),
324 };
325 }
326 }
327
328 pub fn confirm_tool_use(
329 &mut self,
330 tool_use_id: LanguageModelToolUseId,
331 ui_text: impl Into<Arc<str>>,
332 input: serde_json::Value,
333 messages: Arc<Vec<LanguageModelRequestMessage>>,
334 tool: Arc<dyn Tool>,
335 ) {
336 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
337 let ui_text = ui_text.into();
338 tool_use.ui_text = ui_text.clone();
339 let confirmation = Confirmation {
340 tool_use_id,
341 input,
342 messages,
343 tool,
344 ui_text,
345 };
346 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
347 }
348 }
349
350 pub fn insert_tool_output(
351 &mut self,
352 tool_use_id: LanguageModelToolUseId,
353 tool_name: Arc<str>,
354 output: Result<String>,
355 cx: &App,
356 ) -> Option<PendingToolUse> {
357 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
358
359 telemetry::event!(
360 "Agent Tool Finished",
361 model = metadata
362 .as_ref()
363 .map(|metadata| metadata.model.telemetry_id()),
364 model_provider = metadata
365 .as_ref()
366 .map(|metadata| metadata.model.provider_id().to_string()),
367 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
368 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
369 tool_name,
370 success = output.is_ok()
371 );
372
373 match output {
374 Ok(tool_result) => {
375 let model_registry = LanguageModelRegistry::read_global(cx);
376
377 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
378
379 // Protect from clearly large output
380 let tool_output_limit = model_registry
381 .default_model()
382 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
383 .unwrap_or(usize::MAX);
384
385 let tool_result = if tool_result.len() <= tool_output_limit {
386 tool_result
387 } else {
388 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
389
390 format!(
391 "Tool result too long. The first {} bytes:\n\n{}",
392 truncated.len(),
393 truncated
394 )
395 };
396
397 self.tool_results.insert(
398 tool_use_id.clone(),
399 LanguageModelToolResult {
400 tool_use_id: tool_use_id.clone(),
401 tool_name,
402 content: tool_result.into(),
403 is_error: false,
404 },
405 );
406 self.pending_tool_uses_by_id.remove(&tool_use_id)
407 }
408 Err(err) => {
409 self.tool_results.insert(
410 tool_use_id.clone(),
411 LanguageModelToolResult {
412 tool_use_id: tool_use_id.clone(),
413 tool_name,
414 content: err.to_string().into(),
415 is_error: true,
416 },
417 );
418
419 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
420 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
421 }
422
423 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
424 }
425 }
426 }
427
428 pub fn attach_tool_uses(
429 &self,
430 message_id: MessageId,
431 request_message: &mut LanguageModelRequestMessage,
432 ) {
433 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
434 for tool_use in tool_uses {
435 if self.tool_results.contains_key(&tool_use.id) {
436 // Do not send tool uses until they are completed
437 request_message
438 .content
439 .push(MessageContent::ToolUse(tool_use.clone()));
440 } else {
441 log::debug!(
442 "skipped tool use {:?} because it is still pending",
443 tool_use
444 );
445 }
446 }
447 }
448 }
449
450 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
451 self.tool_uses_by_assistant_message
452 .contains_key(&assistant_message_id)
453 }
454
455 pub fn tool_results_message(
456 &self,
457 assistant_message_id: MessageId,
458 ) -> Option<LanguageModelRequestMessage> {
459 let tool_uses = self
460 .tool_uses_by_assistant_message
461 .get(&assistant_message_id)?;
462
463 if tool_uses.is_empty() {
464 return None;
465 }
466
467 let mut request_message = LanguageModelRequestMessage {
468 role: Role::User,
469 content: vec![],
470 cache: false,
471 };
472
473 for tool_use in tool_uses {
474 if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
475 request_message
476 .content
477 .push(MessageContent::ToolResult(LanguageModelToolResult {
478 tool_use_id: tool_use.id.clone(),
479 tool_name: tool_result.tool_name.clone(),
480 is_error: tool_result.is_error,
481 content: if tool_result.content.is_empty() {
482 // Surprisingly, the API fails if we return an empty string here.
483 // It thinks we are sending a tool use without a tool result.
484 "<Tool returned an empty string>".into()
485 } else {
486 tool_result.content.clone()
487 },
488 }));
489 }
490 }
491
492 Some(request_message)
493 }
494}
495
/// A tool use that has been requested but is still being tracked as in flight
/// (streaming input, waiting, awaiting confirmation, running, or failed).
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read in this file; the allow silences the dead-code lint.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Current UI label for the tool use.
    pub ui_text: Arc<str>,
    /// JSON input supplied for the tool.
    pub input: serde_json::Value,
    /// Where the tool use is in its lifecycle.
    pub status: PendingToolUseStatus,
}
507
/// Everything captured when a tool use is put into the `NeedsConfirmation` state,
/// so it can be run once the user confirms.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// UI label shown while awaiting confirmation.
    pub ui_text: Arc<str>,
    /// Request messages captured alongside the confirmation.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to run.
    pub tool: Arc<dyn Tool>,
}
516
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The model is still streaming the tool's input.
    InputStillStreaming,
    /// The input is complete and the tool has not been started yet.
    Idle,
    /// Waiting for the user to confirm the tool use.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running; the task is held here, presumably to keep it alive.
    Running { _task: Shared<Task<()>> },
    /// The tool failed with the given message.
    // NOTE(review): the payload is not read in this file; the allow silences the lint.
    Error(#[allow(unused)] Arc<str>),
}
525
526impl PendingToolUseStatus {
527 pub fn is_idle(&self) -> bool {
528 matches!(self, PendingToolUseStatus::Idle)
529 }
530
531 pub fn is_error(&self) -> bool {
532 matches!(self, PendingToolUseStatus::Error(_))
533 }
534
535 pub fn needs_confirmation(&self) -> bool {
536 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
537 }
538}
539
/// Context recorded for a tool use so that telemetry can be emitted when its
/// output arrives.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model that requested the tool use (used for telemetry and provider ids).
    pub model: Arc<dyn LanguageModel>,
    /// Thread in which the tool use was requested.
    pub thread_id: ThreadId,
    /// Prompt that led to the tool use.
    pub prompt_id: PromptId,
}