1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, Entity, SharedString, Task};
9use language_model::{
10 ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
11 LanguageModelToolUse, LanguageModelToolUseId, Role,
12};
13use project::Project;
14use ui::{IconName, Window};
15use util::truncate_lines_to_byte_limit;
16
17use crate::thread::{MessageId, PromptId, ThreadId};
18use crate::thread_store::SerializedMessage;
19
20#[derive(Debug)]
21pub struct ToolUse {
22 pub id: LanguageModelToolUseId,
23 pub name: SharedString,
24 pub ui_text: SharedString,
25 pub status: ToolUseStatus,
26 pub input: serde_json::Value,
27 pub icon: ui::IconName,
28 pub needs_confirmation: bool,
29}
30
/// Tracks the tool uses requested by the model, together with their pending
/// state, recorded results, result cards, and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in the order they were recorded.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Recorded results (successful, failed, or cancelled), keyed by tool use.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses not yet completed successfully or cancelled; failed tool uses
    /// stay here with an error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Result cards for tool uses whose tool produced one.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Telemetry metadata stored once a tool use's input is complete and
    /// consumed when its output is recorded.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
39
40impl ToolUseState {
41 pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
42 Self {
43 tools,
44 tool_uses_by_assistant_message: HashMap::default(),
45 tool_results: HashMap::default(),
46 pending_tool_uses_by_id: HashMap::default(),
47 tool_result_cards: HashMap::default(),
48 tool_use_metadata_by_id: HashMap::default(),
49 }
50 }
51
52 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
53 ///
54 /// Accepts a function to filter the tools that should be used to populate the state.
55 pub fn from_serialized_messages(
56 tools: Entity<ToolWorkingSet>,
57 messages: &[SerializedMessage],
58 project: Entity<Project>,
59 window: &mut Window,
60 cx: &mut App,
61 ) -> Self {
62 let mut this = Self::new(tools);
63 let mut tool_names_by_id = HashMap::default();
64
65 for message in messages {
66 match message.role {
67 Role::Assistant => {
68 if !message.tool_uses.is_empty() {
69 let tool_uses = message
70 .tool_uses
71 .iter()
72 .map(|tool_use| LanguageModelToolUse {
73 id: tool_use.id.clone(),
74 name: tool_use.name.clone().into(),
75 raw_input: tool_use.input.to_string(),
76 input: tool_use.input.clone(),
77 is_input_complete: true,
78 })
79 .collect::<Vec<_>>();
80
81 tool_names_by_id.extend(
82 tool_uses
83 .iter()
84 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
85 );
86
87 this.tool_uses_by_assistant_message
88 .insert(message.id, tool_uses);
89
90 for tool_result in &message.tool_results {
91 let tool_use_id = tool_result.tool_use_id.clone();
92 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
93 log::warn!("no tool name found for tool use: {tool_use_id:?}");
94 continue;
95 };
96
97 this.tool_results.insert(
98 tool_use_id.clone(),
99 LanguageModelToolResult {
100 tool_use_id: tool_use_id.clone(),
101 tool_name: tool_use.clone(),
102 is_error: tool_result.is_error,
103 content: tool_result.content.clone(),
104 output: tool_result.output.clone(),
105 },
106 );
107
108 if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
109 if let Some(output) = tool_result.output.clone() {
110 if let Some(card) =
111 tool.deserialize_card(output, project.clone(), window, cx)
112 {
113 this.tool_result_cards.insert(tool_use_id, card);
114 }
115 }
116 }
117 }
118 }
119 }
120 Role::System | Role::User => {}
121 }
122 }
123
124 this
125 }
126
127 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
128 let mut cancelled_tool_uses = Vec::new();
129 self.pending_tool_uses_by_id
130 .retain(|tool_use_id, tool_use| {
131 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
132 return true;
133 }
134
135 let content = "Tool canceled by user".into();
136 self.tool_results.insert(
137 tool_use_id.clone(),
138 LanguageModelToolResult {
139 tool_use_id: tool_use_id.clone(),
140 tool_name: tool_use.name.clone(),
141 content,
142 output: None,
143 is_error: true,
144 },
145 );
146 cancelled_tool_uses.push(tool_use.clone());
147 false
148 });
149 cancelled_tool_uses
150 }
151
152 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
153 self.pending_tool_uses_by_id.values().collect()
154 }
155
156 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
157 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
158 return Vec::new();
159 };
160
161 let mut tool_uses = Vec::new();
162
163 for tool_use in tool_uses_for_message.iter() {
164 let tool_result = self.tool_results.get(&tool_use.id);
165
166 let status = (|| {
167 if let Some(tool_result) = tool_result {
168 return if tool_result.is_error {
169 ToolUseStatus::Error(tool_result.content.clone().into())
170 } else {
171 ToolUseStatus::Finished(tool_result.content.clone().into())
172 };
173 }
174
175 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
176 match pending_tool_use.status {
177 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
178 PendingToolUseStatus::NeedsConfirmation { .. } => {
179 ToolUseStatus::NeedsConfirmation
180 }
181 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
182 PendingToolUseStatus::Error(ref err) => {
183 ToolUseStatus::Error(err.clone().into())
184 }
185 PendingToolUseStatus::InputStillStreaming => {
186 ToolUseStatus::InputStillStreaming
187 }
188 }
189 } else {
190 ToolUseStatus::Pending
191 }
192 })();
193
194 let (icon, needs_confirmation) =
195 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
196 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
197 } else {
198 (IconName::Cog, false)
199 };
200
201 tool_uses.push(ToolUse {
202 id: tool_use.id.clone(),
203 name: tool_use.name.clone().into(),
204 ui_text: self.tool_ui_label(
205 &tool_use.name,
206 &tool_use.input,
207 tool_use.is_input_complete,
208 cx,
209 ),
210 input: tool_use.input.clone(),
211 status,
212 icon,
213 needs_confirmation,
214 })
215 }
216
217 tool_uses
218 }
219
220 pub fn tool_ui_label(
221 &self,
222 tool_name: &str,
223 input: &serde_json::Value,
224 is_input_complete: bool,
225 cx: &App,
226 ) -> SharedString {
227 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
228 if is_input_complete {
229 tool.ui_text(input).into()
230 } else {
231 tool.still_streaming_ui_text(input).into()
232 }
233 } else {
234 format!("Unknown tool {tool_name:?}").into()
235 }
236 }
237
238 pub fn tool_results_for_message(
239 &self,
240 assistant_message_id: MessageId,
241 ) -> Vec<&LanguageModelToolResult> {
242 let Some(tool_uses) = self
243 .tool_uses_by_assistant_message
244 .get(&assistant_message_id)
245 else {
246 return Vec::new();
247 };
248
249 tool_uses
250 .iter()
251 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
252 .collect()
253 }
254
255 pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
256 self.tool_uses_by_assistant_message
257 .get(&assistant_message_id)
258 .map_or(false, |results| !results.is_empty())
259 }
260
261 pub fn tool_result(
262 &self,
263 tool_use_id: &LanguageModelToolUseId,
264 ) -> Option<&LanguageModelToolResult> {
265 self.tool_results.get(tool_use_id)
266 }
267
268 pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
269 self.tool_result_cards.get(tool_use_id)
270 }
271
272 pub fn insert_tool_result_card(
273 &mut self,
274 tool_use_id: LanguageModelToolUseId,
275 card: AnyToolCard,
276 ) {
277 self.tool_result_cards.insert(tool_use_id, card);
278 }
279
280 pub fn request_tool_use(
281 &mut self,
282 assistant_message_id: MessageId,
283 tool_use: LanguageModelToolUse,
284 metadata: ToolUseMetadata,
285 cx: &App,
286 ) -> Arc<str> {
287 let tool_uses = self
288 .tool_uses_by_assistant_message
289 .entry(assistant_message_id)
290 .or_default();
291
292 let mut existing_tool_use_found = false;
293
294 for existing_tool_use in tool_uses.iter_mut() {
295 if existing_tool_use.id == tool_use.id {
296 *existing_tool_use = tool_use.clone();
297 existing_tool_use_found = true;
298 }
299 }
300
301 if !existing_tool_use_found {
302 tool_uses.push(tool_use.clone());
303 }
304
305 let status = if tool_use.is_input_complete {
306 self.tool_use_metadata_by_id
307 .insert(tool_use.id.clone(), metadata);
308
309 PendingToolUseStatus::Idle
310 } else {
311 PendingToolUseStatus::InputStillStreaming
312 };
313
314 let ui_text: Arc<str> = self
315 .tool_ui_label(
316 &tool_use.name,
317 &tool_use.input,
318 tool_use.is_input_complete,
319 cx,
320 )
321 .into();
322
323 self.pending_tool_uses_by_id.insert(
324 tool_use.id.clone(),
325 PendingToolUse {
326 assistant_message_id,
327 id: tool_use.id,
328 name: tool_use.name.clone(),
329 ui_text: ui_text.clone(),
330 input: tool_use.input,
331 status,
332 },
333 );
334
335 ui_text
336 }
337
338 pub fn run_pending_tool(
339 &mut self,
340 tool_use_id: LanguageModelToolUseId,
341 ui_text: SharedString,
342 task: Task<()>,
343 ) {
344 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
345 tool_use.ui_text = ui_text.into();
346 tool_use.status = PendingToolUseStatus::Running {
347 _task: task.shared(),
348 };
349 }
350 }
351
352 pub fn confirm_tool_use(
353 &mut self,
354 tool_use_id: LanguageModelToolUseId,
355 ui_text: impl Into<Arc<str>>,
356 input: serde_json::Value,
357 request: Arc<LanguageModelRequest>,
358 tool: Arc<dyn Tool>,
359 ) {
360 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
361 let ui_text = ui_text.into();
362 tool_use.ui_text = ui_text.clone();
363 let confirmation = Confirmation {
364 tool_use_id,
365 input,
366 request,
367 tool,
368 ui_text,
369 };
370 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
371 }
372 }
373
374 pub fn insert_tool_output(
375 &mut self,
376 tool_use_id: LanguageModelToolUseId,
377 tool_name: Arc<str>,
378 output: Result<ToolResultOutput>,
379 configured_model: Option<&ConfiguredModel>,
380 ) -> Option<PendingToolUse> {
381 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
382
383 telemetry::event!(
384 "Agent Tool Finished",
385 model = metadata
386 .as_ref()
387 .map(|metadata| metadata.model.telemetry_id()),
388 model_provider = metadata
389 .as_ref()
390 .map(|metadata| metadata.model.provider_id().to_string()),
391 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
392 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
393 tool_name,
394 success = output.is_ok()
395 );
396
397 match output {
398 Ok(output) => {
399 let tool_result = output.content;
400 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
401
402 // Protect from clearly large output
403 let tool_output_limit = configured_model
404 .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
405 .unwrap_or(usize::MAX);
406
407 let tool_result = if tool_result.len() <= tool_output_limit {
408 tool_result
409 } else {
410 let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
411
412 format!(
413 "Tool result too long. The first {} bytes:\n\n{}",
414 truncated.len(),
415 truncated
416 )
417 };
418
419 self.tool_results.insert(
420 tool_use_id.clone(),
421 LanguageModelToolResult {
422 tool_use_id: tool_use_id.clone(),
423 tool_name,
424 content: tool_result.into(),
425 is_error: false,
426 output: output.output,
427 },
428 );
429 self.pending_tool_uses_by_id.remove(&tool_use_id)
430 }
431 Err(err) => {
432 self.tool_results.insert(
433 tool_use_id.clone(),
434 LanguageModelToolResult {
435 tool_use_id: tool_use_id.clone(),
436 tool_name,
437 content: err.to_string().into(),
438 is_error: true,
439 output: None,
440 },
441 );
442
443 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
444 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
445 }
446
447 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
448 }
449 }
450 }
451
452 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
453 self.tool_uses_by_assistant_message
454 .contains_key(&assistant_message_id)
455 }
456
457 pub fn tool_results(
458 &self,
459 assistant_message_id: MessageId,
460 ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
461 self.tool_uses_by_assistant_message
462 .get(&assistant_message_id)
463 .into_iter()
464 .flatten()
465 .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
466 }
467}
468
/// A tool use that has been requested and is not yet resolved: it is awaiting
/// input, confirmation, or completion, or it has failed.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read within this file, hence the lint suppression;
    // presumably kept for other callers or debugging — confirm before removing.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Human-readable label describing this invocation.
    pub ui_text: Arc<str>,
    /// JSON input supplied for the tool.
    pub input: serde_json::Value,
    /// Where the tool use currently is in its lifecycle.
    pub status: PendingToolUseStatus,
}
480
/// Data captured for a tool use that is waiting on user confirmation
/// (stored in [`PendingToolUseStatus::NeedsConfirmation`]).
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// The tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Input to pass to the tool.
    pub input: serde_json::Value,
    /// Label shown for the tool use while it awaits confirmation.
    pub ui_text: Arc<str>,
    /// The language model request associated with this tool use.
    pub request: Arc<LanguageModelRequest>,
    /// The tool to run once confirmed.
    pub tool: Arc<dyn Tool>,
}
489
490#[derive(Debug, Clone)]
491pub enum PendingToolUseStatus {
492 InputStillStreaming,
493 Idle,
494 NeedsConfirmation(Arc<Confirmation>),
495 Running { _task: Shared<Task<()>> },
496 Error(#[allow(unused)] Arc<str>),
497}
498
499impl PendingToolUseStatus {
500 pub fn is_idle(&self) -> bool {
501 matches!(self, PendingToolUseStatus::Idle)
502 }
503
504 pub fn is_error(&self) -> bool {
505 matches!(self, PendingToolUseStatus::Error(_))
506 }
507
508 pub fn needs_confirmation(&self) -> bool {
509 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
510 }
511}
512
/// Context stored for a tool use so that its completion can be reported in
/// the "Agent Tool Finished" telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model associated with the tool use (reported as `model` and `model_provider`).
    pub model: Arc<dyn LanguageModel>,
    /// Thread the tool use belongs to (reported as `thread_id`).
    pub thread_id: ThreadId,
    /// Prompt associated with the tool use (reported as `prompt_id`).
    pub prompt_id: PromptId,
}