1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14
15use crate::thread::MessageId;
16use crate::thread_store::SerializedMessage;
17
/// A tool use requested by the assistant, resolved for display.
///
/// Instances are built by `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: SharedString,
    /// Human-readable label for this invocation (from `ToolUseState::tool_ui_label`,
    /// which falls back to an "Unknown tool" label when the tool is not registered).
    pub ui_text: SharedString,
    /// Current display status of the tool use.
    pub status: ToolUseStatus,
    /// JSON input the tool was invoked with.
    pub input: serde_json::Value,
    /// Icon of the tool, or `IconName::Cog` when the tool is not in the working set.
    pub icon: ui::IconName,
    /// Whether the tool asks for user confirmation before running
    /// (`false` when the tool is not in the working set).
    pub needs_confirmation: bool,
}
28
/// Display status of a [`ToolUse`].
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// Waiting for the user to confirm the tool use.
    NeedsConfirmation,
    /// Requested but not started yet (also used when no pending entry and no
    /// result exist for the tool use).
    Pending,
    /// The tool is currently running.
    Running,
    /// The tool completed successfully; holds its output.
    Finished(SharedString),
    /// The tool failed or was canceled; holds the error text.
    Error(SharedString),
}
37
38impl ToolUseStatus {
39 pub fn text(&self) -> SharedString {
40 match self {
41 ToolUseStatus::NeedsConfirmation => "".into(),
42 ToolUseStatus::Pending => "".into(),
43 ToolUseStatus::Running => "".into(),
44 ToolUseStatus::Finished(out) => out.clone(),
45 ToolUseStatus::Error(out) => out.clone(),
46 }
47 }
48}
49
/// Tracks the tool uses requested by the assistant, the results recorded for
/// them, and the tool uses that are still pending, indexed by message.
pub struct ToolUseState {
    /// Tool registry used to look up icons, UI labels and confirmation requirements.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested in each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful, failed, or canceled), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and have neither completed successfully
    /// nor been canceled.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
57
/// Placeholder text added to a request message so that a tool use is never
/// sent without accompanying text (see `ToolUseState::attach_tool_uses`).
pub const USING_TOOL_MARKER: &str = "<using_tool>";
59
60impl ToolUseState {
61 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
62 Self {
63 tools,
64 tool_uses_by_assistant_message: HashMap::default(),
65 tool_uses_by_user_message: HashMap::default(),
66 tool_results: HashMap::default(),
67 pending_tool_uses_by_id: HashMap::default(),
68 }
69 }
70
71 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
72 ///
73 /// Accepts a function to filter the tools that should be used to populate the state.
74 pub fn from_serialized_messages(
75 tools: Arc<ToolWorkingSet>,
76 messages: &[SerializedMessage],
77 mut filter_by_tool_name: impl FnMut(&str) -> bool,
78 ) -> Self {
79 let mut this = Self::new(tools);
80 let mut tool_names_by_id = HashMap::default();
81
82 for message in messages {
83 match message.role {
84 Role::Assistant => {
85 if !message.tool_uses.is_empty() {
86 let tool_uses = message
87 .tool_uses
88 .iter()
89 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
90 .map(|tool_use| LanguageModelToolUse {
91 id: tool_use.id.clone(),
92 name: tool_use.name.clone().into(),
93 input: tool_use.input.clone(),
94 })
95 .collect::<Vec<_>>();
96
97 tool_names_by_id.extend(
98 tool_uses
99 .iter()
100 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
101 );
102
103 this.tool_uses_by_assistant_message
104 .insert(message.id, tool_uses);
105 }
106 }
107 Role::User => {
108 if !message.tool_results.is_empty() {
109 let tool_uses_by_user_message = this
110 .tool_uses_by_user_message
111 .entry(message.id)
112 .or_default();
113
114 for tool_result in &message.tool_results {
115 let tool_use_id = tool_result.tool_use_id.clone();
116 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
117 log::warn!("no tool name found for tool use: {tool_use_id:?}");
118 continue;
119 };
120
121 if !(filter_by_tool_name)(tool_use.as_ref()) {
122 continue;
123 }
124
125 tool_uses_by_user_message.push(tool_use_id.clone());
126 this.tool_results.insert(
127 tool_use_id.clone(),
128 LanguageModelToolResult {
129 tool_use_id,
130 tool_name: tool_use.clone(),
131 is_error: tool_result.is_error,
132 content: tool_result.content.clone(),
133 },
134 );
135 }
136 }
137 }
138 Role::System => {}
139 }
140 }
141
142 this
143 }
144
145 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
146 let mut pending_tools = Vec::new();
147 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
148 self.tool_results.insert(
149 tool_use_id.clone(),
150 LanguageModelToolResult {
151 tool_use_id,
152 tool_name: tool_use.name.clone(),
153 content: "Tool canceled by user".into(),
154 is_error: true,
155 },
156 );
157 pending_tools.push(tool_use.clone());
158 }
159 pending_tools
160 }
161
162 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
163 self.pending_tool_uses_by_id.values().collect()
164 }
165
166 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
167 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
168 return Vec::new();
169 };
170
171 let mut tool_uses = Vec::new();
172
173 for tool_use in tool_uses_for_message.iter() {
174 let tool_result = self.tool_results.get(&tool_use.id);
175
176 let status = (|| {
177 if let Some(tool_result) = tool_result {
178 return if tool_result.is_error {
179 ToolUseStatus::Error(tool_result.content.clone().into())
180 } else {
181 ToolUseStatus::Finished(tool_result.content.clone().into())
182 };
183 }
184
185 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
186 match pending_tool_use.status {
187 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
188 PendingToolUseStatus::NeedsConfirmation { .. } => {
189 ToolUseStatus::NeedsConfirmation
190 }
191 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
192 PendingToolUseStatus::Error(ref err) => {
193 ToolUseStatus::Error(err.clone().into())
194 }
195 }
196 } else {
197 ToolUseStatus::Pending
198 }
199 })();
200
201 let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
202 {
203 (tool.icon(), tool.needs_confirmation())
204 } else {
205 (IconName::Cog, false)
206 };
207
208 tool_uses.push(ToolUse {
209 id: tool_use.id.clone(),
210 name: tool_use.name.clone().into(),
211 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
212 input: tool_use.input.clone(),
213 status,
214 icon,
215 needs_confirmation,
216 })
217 }
218
219 tool_uses
220 }
221
222 pub fn tool_ui_label(
223 &self,
224 tool_name: &str,
225 input: &serde_json::Value,
226 cx: &App,
227 ) -> SharedString {
228 if let Some(tool) = self.tools.tool(tool_name, cx) {
229 tool.ui_text(input).into()
230 } else {
231 format!("Unknown tool {tool_name:?}").into()
232 }
233 }
234
235 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
236 let empty = Vec::new();
237
238 self.tool_uses_by_user_message
239 .get(&message_id)
240 .unwrap_or(&empty)
241 .iter()
242 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
243 .collect()
244 }
245
246 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
247 self.tool_uses_by_user_message
248 .get(&message_id)
249 .map_or(false, |results| !results.is_empty())
250 }
251
252 pub fn tool_result(
253 &self,
254 tool_use_id: &LanguageModelToolUseId,
255 ) -> Option<&LanguageModelToolResult> {
256 self.tool_results.get(tool_use_id)
257 }
258
259 pub fn request_tool_use(
260 &mut self,
261 assistant_message_id: MessageId,
262 tool_use: LanguageModelToolUse,
263 cx: &App,
264 ) {
265 self.tool_uses_by_assistant_message
266 .entry(assistant_message_id)
267 .or_default()
268 .push(tool_use.clone());
269
270 // The tool use is being requested by the Assistant, so we want to
271 // attach the tool results to the next user message.
272 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
273 self.tool_uses_by_user_message
274 .entry(next_user_message_id)
275 .or_default()
276 .push(tool_use.id.clone());
277
278 self.pending_tool_uses_by_id.insert(
279 tool_use.id.clone(),
280 PendingToolUse {
281 assistant_message_id,
282 id: tool_use.id,
283 name: tool_use.name.clone(),
284 ui_text: self
285 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
286 .into(),
287 input: tool_use.input,
288 status: PendingToolUseStatus::Idle,
289 },
290 );
291 }
292
293 pub fn run_pending_tool(
294 &mut self,
295 tool_use_id: LanguageModelToolUseId,
296 ui_text: SharedString,
297 task: Task<()>,
298 ) {
299 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
300 tool_use.ui_text = ui_text.into();
301 tool_use.status = PendingToolUseStatus::Running {
302 _task: task.shared(),
303 };
304 }
305 }
306
307 pub fn confirm_tool_use(
308 &mut self,
309 tool_use_id: LanguageModelToolUseId,
310 ui_text: impl Into<Arc<str>>,
311 input: serde_json::Value,
312 messages: Arc<Vec<LanguageModelRequestMessage>>,
313 tool: Arc<dyn Tool>,
314 ) {
315 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
316 let ui_text = ui_text.into();
317 tool_use.ui_text = ui_text.clone();
318 let confirmation = Confirmation {
319 tool_use_id,
320 input,
321 messages,
322 tool,
323 ui_text,
324 };
325 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
326 }
327 }
328
329 pub fn insert_tool_output(
330 &mut self,
331 tool_use_id: LanguageModelToolUseId,
332 tool_name: Arc<str>,
333 output: Result<String>,
334 ) -> Option<PendingToolUse> {
335 match output {
336 Ok(tool_result) => {
337 self.tool_results.insert(
338 tool_use_id.clone(),
339 LanguageModelToolResult {
340 tool_use_id: tool_use_id.clone(),
341 tool_name,
342 content: tool_result.into(),
343 is_error: false,
344 },
345 );
346 self.pending_tool_uses_by_id.remove(&tool_use_id)
347 }
348 Err(err) => {
349 self.tool_results.insert(
350 tool_use_id.clone(),
351 LanguageModelToolResult {
352 tool_use_id: tool_use_id.clone(),
353 tool_name,
354 content: err.to_string().into(),
355 is_error: true,
356 },
357 );
358
359 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
360 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
361 }
362
363 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
364 }
365 }
366 }
367
368 pub fn attach_tool_uses(
369 &self,
370 message_id: MessageId,
371 request_message: &mut LanguageModelRequestMessage,
372 ) {
373 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
374 let mut found_tool_use = false;
375
376 for tool_use in tool_uses {
377 if self.tool_results.contains_key(&tool_use.id) {
378 if !found_tool_use {
379 // The API fails if a message contains a tool use without any (non-whitespace) text around it
380 match request_message.content.last_mut() {
381 Some(MessageContent::Text(txt)) => {
382 if txt.is_empty() {
383 txt.push_str(USING_TOOL_MARKER);
384 }
385 }
386 None | Some(_) => {
387 request_message
388 .content
389 .push(MessageContent::Text(USING_TOOL_MARKER.into()));
390 }
391 };
392 }
393
394 found_tool_use = true;
395
396 // Do not send tool uses until they are completed
397 request_message
398 .content
399 .push(MessageContent::ToolUse(tool_use.clone()));
400 } else {
401 log::debug!(
402 "skipped tool use {:?} because it is still pending",
403 tool_use
404 );
405 }
406 }
407 }
408 }
409
410 pub fn attach_tool_results(
411 &self,
412 message_id: MessageId,
413 request_message: &mut LanguageModelRequestMessage,
414 ) {
415 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
416 for tool_use_id in tool_uses {
417 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
418 request_message.content.push(MessageContent::ToolResult(
419 LanguageModelToolResult {
420 tool_use_id: tool_use_id.clone(),
421 tool_name: tool_result.tool_name.clone(),
422 is_error: tool_result.is_error,
423 content: if tool_result.content.is_empty() {
424 // Surprisingly, the API fails if we return an empty string here.
425 // It thinks we are sending a tool use without a tool result.
426 "<Tool returned an empty string>".into()
427 } else {
428 tool_result.content.clone()
429 },
430 },
431 ));
432 }
433 }
434 }
435 }
436}
437
/// A tool use that has been requested and has neither completed successfully
/// nor been canceled.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): `allow(unused)` presumably because nothing reads this field yet — confirm.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool to invoke.
    pub name: Arc<str>,
    /// Human-readable label for this invocation.
    pub ui_text: Arc<str>,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// Current lifecycle state of the tool use.
    pub status: PendingToolUseStatus,
}
449
/// Context captured when a pending tool use is put into the
/// needs-confirmation state (see `ToolUseState::confirm_tool_use`).
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// Human-readable label for this invocation.
    pub ui_text: Arc<str>,
    /// Request messages supplied by the caller; presumably the conversation
    /// context the tool runs with once confirmed — confirm against callers.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to invoke.
    pub tool: Arc<dyn Tool>,
}
458
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Requested, but neither running nor awaiting confirmation.
    Idle,
    /// Waiting for the user to confirm the tool use.
    NeedsConfirmation(Arc<Confirmation>),
    /// Running. `_task` is never read in this file; presumably it is held to
    /// keep the running task alive — confirm.
    Running { _task: Shared<Task<()>> },
    /// The tool failed; holds the error message.
    // NOTE(review): the payload is read by `ToolUseState` when deriving the
    // display status, so this `allow(unused)` looks stale.
    Error(#[allow(unused)] Arc<str>),
}
466
467impl PendingToolUseStatus {
468 pub fn is_idle(&self) -> bool {
469 matches!(self, PendingToolUseStatus::Idle)
470 }
471
472 pub fn is_error(&self) -> bool {
473 matches!(self, PendingToolUseStatus::Error(_))
474 }
475
476 pub fn needs_confirmation(&self) -> bool {
477 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
478 }
479}