1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14use util::truncate_lines_to_byte_limit;
15
16use crate::thread::{MessageId, PromptId, ThreadId};
17use crate::thread_store::SerializedMessage;
18
/// A tool use together with its current status and presentation data, as
/// returned by [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use, copied from the stored [`LanguageModelToolUse`].
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label for this tool use, as computed by [`ToolUseState::tool_ui_label`].
    pub ui_text: SharedString,
    /// Status derived from the recorded tool result or, failing that, the pending tool use.
    pub status: ToolUseStatus,
    /// JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Icon reported by the tool; `IconName::Cog` when the tool is not found in the working set.
    pub icon: ui::IconName,
    /// Whether the tool reports that this input needs confirmation; `false` when the tool is
    /// not found in the working set.
    pub needs_confirmation: bool,
}
29
/// Bookkeeping for tool uses: which messages requested them, which are still
/// pending, and which results have been recorded.
pub struct ToolUseState {
    /// Working set used to look tools up by name (for labels, icons and confirmation checks).
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful, failed or canceled), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have been requested; entries are removed when a successful
    /// output is inserted or when pending tool uses are canceled.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered for tool uses via `insert_tool_result_card`.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored once a tool use's input is complete and consumed when its
    /// output is inserted (it feeds the "Agent Tool Finished" telemetry event).
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
39
40impl ToolUseState {
41 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
42 Self {
43 tools,
44 tool_uses_by_assistant_message: HashMap::default(),
45 tool_uses_by_user_message: HashMap::default(),
46 tool_results: HashMap::default(),
47 pending_tool_uses_by_id: HashMap::default(),
48 tool_result_cards: HashMap::default(),
49 tool_use_metadata_by_id: HashMap::default(),
50 }
51 }
52
53 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
54 ///
55 /// Accepts a function to filter the tools that should be used to populate the state.
56 pub fn from_serialized_messages(
57 tools: Entity<ToolWorkingSet>,
58 messages: &[SerializedMessage],
59 mut filter_by_tool_name: impl FnMut(&str) -> bool,
60 ) -> Self {
61 let mut this = Self::new(tools);
62 let mut tool_names_by_id = HashMap::default();
63
64 for message in messages {
65 match message.role {
66 Role::Assistant => {
67 if !message.tool_uses.is_empty() {
68 let tool_uses = message
69 .tool_uses
70 .iter()
71 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
72 .map(|tool_use| LanguageModelToolUse {
73 id: tool_use.id.clone(),
74 name: tool_use.name.clone().into(),
75 input: tool_use.input.clone(),
76 is_input_complete: true,
77 })
78 .collect::<Vec<_>>();
79
80 tool_names_by_id.extend(
81 tool_uses
82 .iter()
83 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
84 );
85
86 this.tool_uses_by_assistant_message
87 .insert(message.id, tool_uses);
88 }
89 }
90 Role::User => {
91 if !message.tool_results.is_empty() {
92 let tool_uses_by_user_message = this
93 .tool_uses_by_user_message
94 .entry(message.id)
95 .or_default();
96
97 for tool_result in &message.tool_results {
98 let tool_use_id = tool_result.tool_use_id.clone();
99 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100 log::warn!("no tool name found for tool use: {tool_use_id:?}");
101 continue;
102 };
103
104 if !(filter_by_tool_name)(tool_use.as_ref()) {
105 continue;
106 }
107
108 tool_uses_by_user_message.push(tool_use_id.clone());
109 this.tool_results.insert(
110 tool_use_id.clone(),
111 LanguageModelToolResult {
112 tool_use_id,
113 tool_name: tool_use.clone(),
114 is_error: tool_result.is_error,
115 content: tool_result.content.clone(),
116 },
117 );
118 }
119 }
120 }
121 Role::System => {}
122 }
123 }
124
125 this
126 }
127
128 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
129 let mut pending_tools = Vec::new();
130 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
131 self.tool_results.insert(
132 tool_use_id.clone(),
133 LanguageModelToolResult {
134 tool_use_id,
135 tool_name: tool_use.name.clone(),
136 content: "Tool canceled by user".into(),
137 is_error: true,
138 },
139 );
140 pending_tools.push(tool_use.clone());
141 }
142 pending_tools
143 }
144
145 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
146 self.pending_tool_uses_by_id.values().collect()
147 }
148
149 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
150 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
151 return Vec::new();
152 };
153
154 let mut tool_uses = Vec::new();
155
156 for tool_use in tool_uses_for_message.iter() {
157 let tool_result = self.tool_results.get(&tool_use.id);
158
159 let status = (|| {
160 if let Some(tool_result) = tool_result {
161 return if tool_result.is_error {
162 ToolUseStatus::Error(tool_result.content.clone().into())
163 } else {
164 ToolUseStatus::Finished(tool_result.content.clone().into())
165 };
166 }
167
168 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
169 match pending_tool_use.status {
170 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
171 PendingToolUseStatus::NeedsConfirmation { .. } => {
172 ToolUseStatus::NeedsConfirmation
173 }
174 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
175 PendingToolUseStatus::Error(ref err) => {
176 ToolUseStatus::Error(err.clone().into())
177 }
178 PendingToolUseStatus::InputStillStreaming => {
179 ToolUseStatus::InputStillStreaming
180 }
181 }
182 } else {
183 ToolUseStatus::Pending
184 }
185 })();
186
187 let (icon, needs_confirmation) =
188 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
189 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
190 } else {
191 (IconName::Cog, false)
192 };
193
194 tool_uses.push(ToolUse {
195 id: tool_use.id.clone(),
196 name: tool_use.name.clone().into(),
197 ui_text: self.tool_ui_label(
198 &tool_use.name,
199 &tool_use.input,
200 tool_use.is_input_complete,
201 cx,
202 ),
203 input: tool_use.input.clone(),
204 status,
205 icon,
206 needs_confirmation,
207 })
208 }
209
210 tool_uses
211 }
212
213 pub fn tool_ui_label(
214 &self,
215 tool_name: &str,
216 input: &serde_json::Value,
217 is_input_complete: bool,
218 cx: &App,
219 ) -> SharedString {
220 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
221 if is_input_complete {
222 tool.ui_text(input).into()
223 } else {
224 tool.still_streaming_ui_text(input).into()
225 }
226 } else {
227 format!("Unknown tool {tool_name:?}").into()
228 }
229 }
230
231 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
232 let empty = Vec::new();
233
234 self.tool_uses_by_user_message
235 .get(&message_id)
236 .unwrap_or(&empty)
237 .iter()
238 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
239 .collect()
240 }
241
242 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
243 self.tool_uses_by_user_message
244 .get(&message_id)
245 .map_or(false, |results| !results.is_empty())
246 }
247
248 pub fn tool_result(
249 &self,
250 tool_use_id: &LanguageModelToolUseId,
251 ) -> Option<&LanguageModelToolResult> {
252 self.tool_results.get(tool_use_id)
253 }
254
255 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
256 self.tool_result_cards.get(tool_use_id)
257 }
258
259 pub fn insert_tool_result_card(
260 &mut self,
261 tool_use_id: LanguageModelToolUseId,
262 card: AnyToolCard,
263 ) {
264 self.tool_result_cards.insert(tool_use_id, card);
265 }
266
267 pub fn request_tool_use(
268 &mut self,
269 assistant_message_id: MessageId,
270 tool_use: LanguageModelToolUse,
271 metadata: ToolUseMetadata,
272 cx: &App,
273 ) -> Arc<str> {
274 let tool_uses = self
275 .tool_uses_by_assistant_message
276 .entry(assistant_message_id)
277 .or_default();
278
279 let mut existing_tool_use_found = false;
280
281 for existing_tool_use in tool_uses.iter_mut() {
282 if existing_tool_use.id == tool_use.id {
283 *existing_tool_use = tool_use.clone();
284 existing_tool_use_found = true;
285 }
286 }
287
288 if !existing_tool_use_found {
289 tool_uses.push(tool_use.clone());
290 }
291
292 let status = if tool_use.is_input_complete {
293 self.tool_use_metadata_by_id
294 .insert(tool_use.id.clone(), metadata);
295
296 // The tool use is being requested by the Assistant, so we want to
297 // attach the tool results to the next user message.
298 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
299 self.tool_uses_by_user_message
300 .entry(next_user_message_id)
301 .or_default()
302 .push(tool_use.id.clone());
303
304 PendingToolUseStatus::Idle
305 } else {
306 PendingToolUseStatus::InputStillStreaming
307 };
308
309 let ui_text: Arc<str> = self
310 .tool_ui_label(
311 &tool_use.name,
312 &tool_use.input,
313 tool_use.is_input_complete,
314 cx,
315 )
316 .into();
317
318 self.pending_tool_uses_by_id.insert(
319 tool_use.id.clone(),
320 PendingToolUse {
321 assistant_message_id,
322 id: tool_use.id,
323 name: tool_use.name.clone(),
324 ui_text: ui_text.clone(),
325 input: tool_use.input,
326 status,
327 },
328 );
329
330 ui_text
331 }
332
333 pub fn run_pending_tool(
334 &mut self,
335 tool_use_id: LanguageModelToolUseId,
336 ui_text: SharedString,
337 task: Task<()>,
338 ) {
339 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
340 tool_use.ui_text = ui_text.into();
341 tool_use.status = PendingToolUseStatus::Running {
342 _task: task.shared(),
343 };
344 }
345 }
346
347 pub fn confirm_tool_use(
348 &mut self,
349 tool_use_id: LanguageModelToolUseId,
350 ui_text: impl Into<Arc<str>>,
351 input: serde_json::Value,
352 messages: Arc<Vec<LanguageModelRequestMessage>>,
353 tool: Arc<dyn Tool>,
354 ) {
355 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
356 let ui_text = ui_text.into();
357 tool_use.ui_text = ui_text.clone();
358 let confirmation = Confirmation {
359 tool_use_id,
360 input,
361 messages,
362 tool,
363 ui_text,
364 };
365 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
366 }
367 }
368
369 pub fn insert_tool_output(
370 &mut self,
371 tool_use_id: LanguageModelToolUseId,
372 tool_name: Arc<str>,
373 output: Result<String>,
374 cx: &App,
375 ) -> Option<PendingToolUse> {
376 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
377
378 telemetry::event!(
379 "Agent Tool Finished",
380 model = metadata
381 .as_ref()
382 .map(|metadata| metadata.model.telemetry_id()),
383 model_provider = metadata
384 .as_ref()
385 .map(|metadata| metadata.model.provider_id().to_string()),
386 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
387 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
388 tool_name,
389 success = output.is_ok()
390 );
391
392 match output {
393 Ok(tool_result) => {
394 let model_registry = LanguageModelRegistry::read_global(cx);
395
396 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
397
398 // Protect from clearly large output
399 let tool_output_limit = model_registry
400 .default_model()
401 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
402 .unwrap_or(usize::MAX);
403
404 let tool_result = if tool_result.len() <= tool_output_limit {
405 tool_result
406 } else {
407 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
408
409 format!(
410 "Tool result too long. The first {} bytes:\n\n{}",
411 truncated.len(),
412 truncated
413 )
414 };
415
416 self.tool_results.insert(
417 tool_use_id.clone(),
418 LanguageModelToolResult {
419 tool_use_id: tool_use_id.clone(),
420 tool_name,
421 content: tool_result.into(),
422 is_error: false,
423 },
424 );
425 self.pending_tool_uses_by_id.remove(&tool_use_id)
426 }
427 Err(err) => {
428 self.tool_results.insert(
429 tool_use_id.clone(),
430 LanguageModelToolResult {
431 tool_use_id: tool_use_id.clone(),
432 tool_name,
433 content: err.to_string().into(),
434 is_error: true,
435 },
436 );
437
438 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
439 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
440 }
441
442 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
443 }
444 }
445 }
446
447 pub fn attach_tool_uses(
448 &self,
449 message_id: MessageId,
450 request_message: &mut LanguageModelRequestMessage,
451 ) {
452 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
453 for tool_use in tool_uses {
454 if self.tool_results.contains_key(&tool_use.id) {
455 // Do not send tool uses until they are completed
456 request_message
457 .content
458 .push(MessageContent::ToolUse(tool_use.clone()));
459 } else {
460 log::debug!(
461 "skipped tool use {:?} because it is still pending",
462 tool_use
463 );
464 }
465 }
466 }
467 }
468
469 pub fn attach_tool_results(
470 &self,
471 message_id: MessageId,
472 request_message: &mut LanguageModelRequestMessage,
473 ) {
474 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
475 for tool_use_id in tool_uses {
476 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
477 request_message.content.push(MessageContent::ToolResult(
478 LanguageModelToolResult {
479 tool_use_id: tool_use_id.clone(),
480 tool_name: tool_result.tool_name.clone(),
481 is_error: tool_result.is_error,
482 content: if tool_result.content.is_empty() {
483 // Surprisingly, the API fails if we return an empty string here.
484 // It thinks we are sending a tool use without a tool result.
485 "<Tool returned an empty string>".into()
486 } else {
487 tool_result.content.clone()
488 },
489 },
490 ));
491 }
492 }
493 }
494 }
495}
496
/// A tool use that has been requested and is tracked in the pending set of
/// [`ToolUseState`].
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; replaced when the tool use starts running or is
    /// put up for confirmation.
    pub ui_text: Arc<str>,
    /// JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Current state of the tool use.
    pub status: PendingToolUseStatus,
}
508
/// Data stored with a tool use that is awaiting confirmation; built by
/// [`ToolUseState::confirm_tool_use`].
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input passed to `confirm_tool_use`.
    pub input: serde_json::Value,
    /// Label of the tool use (the same text stored on the pending tool use).
    pub ui_text: Arc<str>,
    /// Request messages passed to `confirm_tool_use`.
    // NOTE(review): presumably the conversation handed to the tool once it runs — confirm.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool that was passed to `confirm_tool_use`.
    pub tool: Arc<dyn Tool>,
}
517
/// State of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// The tool use was requested before its input was complete.
    InputStillStreaming,
    /// The input is complete and the tool use has not been started or put up for confirmation.
    Idle,
    /// The tool use is waiting for confirmation; carries the data captured for it.
    NeedsConfirmation(Arc<Confirmation>),
    /// The tool use is running.
    // NOTE(review): `_task` is presumably held only to keep the task alive — confirm.
    Running { _task: Shared<Task<()>> },
    /// The tool run failed; carries the error message.
    Error(#[allow(unused)] Arc<str>),
}
526
527impl PendingToolUseStatus {
528 pub fn is_idle(&self) -> bool {
529 matches!(self, PendingToolUseStatus::Idle)
530 }
531
532 pub fn is_error(&self) -> bool {
533 matches!(self, PendingToolUseStatus::Error(_))
534 }
535
536 pub fn needs_confirmation(&self) -> bool {
537 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
538 }
539}
540
/// Context recorded for a tool use once its input is complete; it is reported with
/// the "Agent Tool Finished" telemetry event when the tool's output is inserted.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Language model whose telemetry ID and provider ID are reported.
    pub model: Arc<dyn LanguageModel>,
    /// ID of the thread, reported as `thread_id`.
    pub thread_id: ThreadId,
    /// ID of the prompt, reported as `prompt_id`.
    pub prompt_id: PromptId,
}