1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
19#[derive(Debug)]
20pub struct ToolUse {
21 pub id: LanguageModelToolUseId,
22 pub name: SharedString,
23 pub ui_text: SharedString,
24 pub status: ToolUseStatus,
25 pub input: serde_json::Value,
26 pub icon: ui::IconName,
27 pub needs_confirmation: bool,
28}
29
/// Placeholder text inserted into a request message ahead of its tool uses when
/// the message has no text next to them; the API rejects tool uses without
/// surrounding text (see `ToolUseState::attach_tool_uses`).
pub const USING_TOOL_MARKER: &str = "<using_tool>";
31
/// Tracks the tool uses of a thread: which assistant message requested them,
/// which user message carries their results, and their in-flight state.
pub struct ToolUseState {
    /// Working set used to look up tools by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested in each assistant message, in request order.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Ids of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful or error) keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have not produced a successful result yet.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Optional UI cards associated with tool results.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Telemetry metadata recorded when a tool use's input is complete.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
41
42impl ToolUseState {
43 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
44 Self {
45 tools,
46 tool_uses_by_assistant_message: HashMap::default(),
47 tool_uses_by_user_message: HashMap::default(),
48 tool_results: HashMap::default(),
49 pending_tool_uses_by_id: HashMap::default(),
50 tool_result_cards: HashMap::default(),
51 tool_use_metadata_by_id: HashMap::default(),
52 }
53 }
54
55 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
56 ///
57 /// Accepts a function to filter the tools that should be used to populate the state.
58 pub fn from_serialized_messages(
59 tools: Entity<ToolWorkingSet>,
60 messages: &[SerializedMessage],
61 mut filter_by_tool_name: impl FnMut(&str) -> bool,
62 ) -> Self {
63 let mut this = Self::new(tools);
64 let mut tool_names_by_id = HashMap::default();
65
66 for message in messages {
67 match message.role {
68 Role::Assistant => {
69 if !message.tool_uses.is_empty() {
70 let tool_uses = message
71 .tool_uses
72 .iter()
73 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
74 .map(|tool_use| LanguageModelToolUse {
75 id: tool_use.id.clone(),
76 name: tool_use.name.clone().into(),
77 input: tool_use.input.clone(),
78 is_input_complete: true,
79 })
80 .collect::<Vec<_>>();
81
82 tool_names_by_id.extend(
83 tool_uses
84 .iter()
85 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
86 );
87
88 this.tool_uses_by_assistant_message
89 .insert(message.id, tool_uses);
90 }
91 }
92 Role::User => {
93 if !message.tool_results.is_empty() {
94 let tool_uses_by_user_message = this
95 .tool_uses_by_user_message
96 .entry(message.id)
97 .or_default();
98
99 for tool_result in &message.tool_results {
100 let tool_use_id = tool_result.tool_use_id.clone();
101 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
102 log::warn!("no tool name found for tool use: {tool_use_id:?}");
103 continue;
104 };
105
106 if !(filter_by_tool_name)(tool_use.as_ref()) {
107 continue;
108 }
109
110 tool_uses_by_user_message.push(tool_use_id.clone());
111 this.tool_results.insert(
112 tool_use_id.clone(),
113 LanguageModelToolResult {
114 tool_use_id,
115 tool_name: tool_use.clone(),
116 is_error: tool_result.is_error,
117 content: tool_result.content.clone(),
118 },
119 );
120 }
121 }
122 }
123 Role::System => {}
124 }
125 }
126
127 this
128 }
129
130 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
131 let mut pending_tools = Vec::new();
132 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
133 self.tool_results.insert(
134 tool_use_id.clone(),
135 LanguageModelToolResult {
136 tool_use_id,
137 tool_name: tool_use.name.clone(),
138 content: "Tool canceled by user".into(),
139 is_error: true,
140 },
141 );
142 pending_tools.push(tool_use.clone());
143 }
144 pending_tools
145 }
146
147 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
148 self.pending_tool_uses_by_id.values().collect()
149 }
150
151 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
152 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
153 return Vec::new();
154 };
155
156 let mut tool_uses = Vec::new();
157
158 for tool_use in tool_uses_for_message.iter() {
159 let tool_result = self.tool_results.get(&tool_use.id);
160
161 let status = (|| {
162 if let Some(tool_result) = tool_result {
163 return if tool_result.is_error {
164 ToolUseStatus::Error(tool_result.content.clone().into())
165 } else {
166 ToolUseStatus::Finished(tool_result.content.clone().into())
167 };
168 }
169
170 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
171 match pending_tool_use.status {
172 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
173 PendingToolUseStatus::NeedsConfirmation { .. } => {
174 ToolUseStatus::NeedsConfirmation
175 }
176 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
177 PendingToolUseStatus::Error(ref err) => {
178 ToolUseStatus::Error(err.clone().into())
179 }
180 PendingToolUseStatus::InputStillStreaming => {
181 ToolUseStatus::InputStillStreaming
182 }
183 }
184 } else {
185 ToolUseStatus::Pending
186 }
187 })();
188
189 let (icon, needs_confirmation) =
190 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
191 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
192 } else {
193 (IconName::Cog, false)
194 };
195
196 tool_uses.push(ToolUse {
197 id: tool_use.id.clone(),
198 name: tool_use.name.clone().into(),
199 ui_text: self.tool_ui_label(
200 &tool_use.name,
201 &tool_use.input,
202 tool_use.is_input_complete,
203 cx,
204 ),
205 input: tool_use.input.clone(),
206 status,
207 icon,
208 needs_confirmation,
209 })
210 }
211
212 tool_uses
213 }
214
215 pub fn tool_ui_label(
216 &self,
217 tool_name: &str,
218 input: &serde_json::Value,
219 is_input_complete: bool,
220 cx: &App,
221 ) -> SharedString {
222 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
223 if is_input_complete {
224 tool.ui_text(input).into()
225 } else {
226 tool.still_streaming_ui_text(input).into()
227 }
228 } else {
229 format!("Unknown tool {tool_name:?}").into()
230 }
231 }
232
233 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
234 let empty = Vec::new();
235
236 self.tool_uses_by_user_message
237 .get(&message_id)
238 .unwrap_or(&empty)
239 .iter()
240 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
241 .collect()
242 }
243
244 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
245 self.tool_uses_by_user_message
246 .get(&message_id)
247 .map_or(false, |results| !results.is_empty())
248 }
249
250 pub fn tool_result(
251 &self,
252 tool_use_id: &LanguageModelToolUseId,
253 ) -> Option<&LanguageModelToolResult> {
254 self.tool_results.get(tool_use_id)
255 }
256
257 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
258 self.tool_result_cards.get(tool_use_id)
259 }
260
261 pub fn insert_tool_result_card(
262 &mut self,
263 tool_use_id: LanguageModelToolUseId,
264 card: AnyToolCard,
265 ) {
266 self.tool_result_cards.insert(tool_use_id, card);
267 }
268
269 pub fn request_tool_use(
270 &mut self,
271 assistant_message_id: MessageId,
272 tool_use: LanguageModelToolUse,
273 metadata: ToolUseMetadata,
274 cx: &App,
275 ) -> Arc<str> {
276 let tool_uses = self
277 .tool_uses_by_assistant_message
278 .entry(assistant_message_id)
279 .or_default();
280
281 let mut existing_tool_use_found = false;
282
283 for existing_tool_use in tool_uses.iter_mut() {
284 if existing_tool_use.id == tool_use.id {
285 *existing_tool_use = tool_use.clone();
286 existing_tool_use_found = true;
287 }
288 }
289
290 if !existing_tool_use_found {
291 tool_uses.push(tool_use.clone());
292 }
293
294 let status = if tool_use.is_input_complete {
295 self.tool_use_metadata_by_id
296 .insert(tool_use.id.clone(), metadata);
297
298 // The tool use is being requested by the Assistant, so we want to
299 // attach the tool results to the next user message.
300 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
301 self.tool_uses_by_user_message
302 .entry(next_user_message_id)
303 .or_default()
304 .push(tool_use.id.clone());
305
306 PendingToolUseStatus::Idle
307 } else {
308 PendingToolUseStatus::InputStillStreaming
309 };
310
311 let ui_text: Arc<str> = self
312 .tool_ui_label(
313 &tool_use.name,
314 &tool_use.input,
315 tool_use.is_input_complete,
316 cx,
317 )
318 .into();
319
320 self.pending_tool_uses_by_id.insert(
321 tool_use.id.clone(),
322 PendingToolUse {
323 assistant_message_id,
324 id: tool_use.id,
325 name: tool_use.name.clone(),
326 ui_text: ui_text.clone(),
327 input: tool_use.input,
328 status,
329 },
330 );
331
332 ui_text
333 }
334
335 pub fn run_pending_tool(
336 &mut self,
337 tool_use_id: LanguageModelToolUseId,
338 ui_text: SharedString,
339 task: Task<()>,
340 ) {
341 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
342 tool_use.ui_text = ui_text.into();
343 tool_use.status = PendingToolUseStatus::Running {
344 _task: task.shared(),
345 };
346 }
347 }
348
349 pub fn confirm_tool_use(
350 &mut self,
351 tool_use_id: LanguageModelToolUseId,
352 ui_text: impl Into<Arc<str>>,
353 input: serde_json::Value,
354 messages: Arc<Vec<LanguageModelRequestMessage>>,
355 tool: Arc<dyn Tool>,
356 ) {
357 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
358 let ui_text = ui_text.into();
359 tool_use.ui_text = ui_text.clone();
360 let confirmation = Confirmation {
361 tool_use_id,
362 input,
363 messages,
364 tool,
365 ui_text,
366 };
367 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
368 }
369 }
370
371 pub fn insert_tool_output(
372 &mut self,
373 tool_use_id: LanguageModelToolUseId,
374 tool_name: Arc<str>,
375 output: Result<String>,
376 cx: &App,
377 ) -> Option<PendingToolUse> {
378 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
379
380 telemetry::event!(
381 "Agent Tool Finished",
382 model = metadata
383 .as_ref()
384 .map(|metadata| metadata.model.telemetry_id()),
385 model_provider = metadata
386 .as_ref()
387 .map(|metadata| metadata.model.provider_id().to_string()),
388 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
389 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
390 tool_name,
391 success = output.is_ok()
392 );
393
394 match output {
395 Ok(tool_result) => {
396 let model_registry = LanguageModelRegistry::read_global(cx);
397
398 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
399
400 // Protect from clearly large output
401 let tool_output_limit = model_registry
402 .default_model()
403 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
404 .unwrap_or(usize::MAX);
405
406 let tool_result = if tool_result.len() <= tool_output_limit {
407 tool_result
408 } else {
409 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
410
411 format!(
412 "Tool result too long. The first {} bytes:\n\n{}",
413 truncated.len(),
414 truncated
415 )
416 };
417
418 self.tool_results.insert(
419 tool_use_id.clone(),
420 LanguageModelToolResult {
421 tool_use_id: tool_use_id.clone(),
422 tool_name,
423 content: tool_result.into(),
424 is_error: false,
425 },
426 );
427 self.pending_tool_uses_by_id.remove(&tool_use_id)
428 }
429 Err(err) => {
430 self.tool_results.insert(
431 tool_use_id.clone(),
432 LanguageModelToolResult {
433 tool_use_id: tool_use_id.clone(),
434 tool_name,
435 content: err.to_string().into(),
436 is_error: true,
437 },
438 );
439
440 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
441 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
442 }
443
444 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
445 }
446 }
447 }
448
449 pub fn attach_tool_uses(
450 &self,
451 message_id: MessageId,
452 request_message: &mut LanguageModelRequestMessage,
453 ) {
454 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
455 let mut found_tool_use = false;
456
457 for tool_use in tool_uses {
458 if self.tool_results.contains_key(&tool_use.id) {
459 if !found_tool_use {
460 // The API fails if a message contains a tool use without any (non-whitespace) text around it
461 match request_message.content.last_mut() {
462 Some(MessageContent::Text(txt)) => {
463 if txt.is_empty() {
464 txt.push_str(USING_TOOL_MARKER);
465 }
466 }
467 None | Some(_) => {
468 request_message
469 .content
470 .push(MessageContent::Text(USING_TOOL_MARKER.into()));
471 }
472 };
473 }
474
475 found_tool_use = true;
476
477 // Do not send tool uses until they are completed
478 request_message
479 .content
480 .push(MessageContent::ToolUse(tool_use.clone()));
481 } else {
482 log::debug!(
483 "skipped tool use {:?} because it is still pending",
484 tool_use
485 );
486 }
487 }
488 }
489 }
490
491 pub fn attach_tool_results(
492 &self,
493 message_id: MessageId,
494 request_message: &mut LanguageModelRequestMessage,
495 ) {
496 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
497 for tool_use_id in tool_uses {
498 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
499 request_message.content.push(MessageContent::ToolResult(
500 LanguageModelToolResult {
501 tool_use_id: tool_use_id.clone(),
502 tool_name: tool_result.tool_name.clone(),
503 is_error: tool_result.is_error,
504 content: if tool_result.content.is_empty() {
505 // Surprisingly, the API fails if we return an empty string here.
506 // It thinks we are sending a tool use without a tool result.
507 "<Tool returned an empty string>".into()
508 } else {
509 tool_result.content.clone()
510 },
511 },
512 ));
513 }
514 }
515 }
516 }
517}
518
/// A tool use that has been requested but has not produced a successful result yet.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read within this file; presumably kept for callers/debugging.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Current UI label for the tool use.
    pub ui_text: Arc<str>,
    /// JSON input supplied for the tool (possibly partial while streaming).
    pub input: serde_json::Value,
    /// Where the tool use is in its lifecycle.
    pub status: PendingToolUseStatus,
}
530
/// Everything needed to run a tool use once the user confirms it.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input to run the tool with.
    pub input: serde_json::Value,
    /// UI label shown while awaiting confirmation.
    pub ui_text: Arc<str>,
    /// Request messages captured when confirmation was requested.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to run.
    pub tool: Arc<dyn Tool>,
}
539
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The model is still streaming the tool's input.
    InputStillStreaming,
    /// Input is complete; the tool has not been started yet.
    Idle,
    /// Waiting for the user to confirm running the tool.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool is running; the shared task is held so it is not dropped.
    Running { _task: Shared<Task<()>> },
    /// The tool failed with the given message.
    // NOTE(review): the payload is read by `ToolUseState::tool_uses_for_message`,
    // so this `allow(unused)` looks stale — confirm before removing.
    Error(#[allow(unused)] Arc<str>),
}
548
549impl PendingToolUseStatus {
550 pub fn is_idle(&self) -> bool {
551 matches!(self, PendingToolUseStatus::Idle)
552 }
553
554 pub fn is_error(&self) -> bool {
555 matches!(self, PendingToolUseStatus::Error(_))
556 }
557
558 pub fn needs_confirmation(&self) -> bool {
559 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
560 }
561}
562
/// Context recorded for a tool use once its input is complete; consumed when
/// the tool's output is inserted (telemetry).
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// The model that requested the tool use.
    pub model: Arc<dyn LanguageModel>,
    /// Thread in which the tool use occurred.
    pub thread_id: ThreadId,
    /// Prompt that led to the tool use.
    pub prompt_id: PromptId,
}