1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::future::Shared;
7use futures::FutureExt as _;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14
15use crate::thread::MessageId;
16use crate::thread_store::SerializedMessage;
17
/// A display-oriented view of a single tool use, as built by
/// `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use, copied from the originating request.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Human-readable label for this use; falls back to an "Unknown tool" label
    /// when the tool cannot be found in the working set.
    pub ui_text: SharedString,
    /// Current state of the tool use, derived from stored results and pending state.
    pub status: ToolUseStatus,
    /// The JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Icon reported by the tool; `IconName::Cog` when the tool cannot be found.
    pub icon: ui::IconName,
    /// Whether the tool reports that it needs confirmation; `false` when the tool
    /// cannot be found.
    pub needs_confirmation: bool,
}
28
/// The externally visible state of a tool use, as computed by
/// `ToolUseState::tool_uses_for_message`.
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The pending tool use is waiting on a confirmation.
    NeedsConfirmation,
    /// No result is recorded and the tool use is idle, or has no pending entry at all.
    Pending,
    /// The pending tool use is currently running.
    Running,
    /// A non-error result is recorded; carries the result content.
    Finished(SharedString),
    /// An error result is recorded, or the pending tool use is in its error state;
    /// carries the error text.
    Error(SharedString),
}
37
/// Tracks tool uses requested by assistant messages, their results, and the
/// tool uses that are still pending.
pub struct ToolUseState {
    /// Working set used to look tools up by name (for labels, icons, and
    /// confirmation requirements).
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in request order.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results, keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses registered through `request_tool_use`, keyed by tool use ID.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
45
46impl ToolUseState {
47 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
48 Self {
49 tools,
50 tool_uses_by_assistant_message: HashMap::default(),
51 tool_uses_by_user_message: HashMap::default(),
52 tool_results: HashMap::default(),
53 pending_tool_uses_by_id: HashMap::default(),
54 }
55 }
56
57 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
58 ///
59 /// Accepts a function to filter the tools that should be used to populate the state.
60 pub fn from_serialized_messages(
61 tools: Arc<ToolWorkingSet>,
62 messages: &[SerializedMessage],
63 mut filter_by_tool_name: impl FnMut(&str) -> bool,
64 ) -> Self {
65 let mut this = Self::new(tools);
66 let mut tool_names_by_id = HashMap::default();
67
68 for message in messages {
69 match message.role {
70 Role::Assistant => {
71 if !message.tool_uses.is_empty() {
72 let tool_uses = message
73 .tool_uses
74 .iter()
75 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
76 .map(|tool_use| LanguageModelToolUse {
77 id: tool_use.id.clone(),
78 name: tool_use.name.clone().into(),
79 input: tool_use.input.clone(),
80 })
81 .collect::<Vec<_>>();
82
83 tool_names_by_id.extend(
84 tool_uses
85 .iter()
86 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
87 );
88
89 this.tool_uses_by_assistant_message
90 .insert(message.id, tool_uses);
91 }
92 }
93 Role::User => {
94 if !message.tool_results.is_empty() {
95 let tool_uses_by_user_message = this
96 .tool_uses_by_user_message
97 .entry(message.id)
98 .or_default();
99
100 for tool_result in &message.tool_results {
101 let tool_use_id = tool_result.tool_use_id.clone();
102 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
103 log::warn!("no tool name found for tool use: {tool_use_id:?}");
104 continue;
105 };
106
107 if !(filter_by_tool_name)(tool_use.as_ref()) {
108 continue;
109 }
110
111 tool_uses_by_user_message.push(tool_use_id.clone());
112 this.tool_results.insert(
113 tool_use_id.clone(),
114 LanguageModelToolResult {
115 tool_use_id,
116 is_error: tool_result.is_error,
117 content: tool_result.content.clone(),
118 },
119 );
120 }
121 }
122 }
123 Role::System => {}
124 }
125 }
126
127 this
128 }
129
130 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
131 let mut pending_tools = Vec::new();
132 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
133 self.tool_results.insert(
134 tool_use_id.clone(),
135 LanguageModelToolResult {
136 tool_use_id,
137 content: "Tool canceled by user".into(),
138 is_error: true,
139 },
140 );
141 pending_tools.push(tool_use.clone());
142 }
143 pending_tools
144 }
145
146 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
147 self.pending_tool_uses_by_id.values().collect()
148 }
149
150 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
151 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
152 return Vec::new();
153 };
154
155 let mut tool_uses = Vec::new();
156
157 for tool_use in tool_uses_for_message.iter() {
158 let tool_result = self.tool_results.get(&tool_use.id);
159
160 let status = (|| {
161 if let Some(tool_result) = tool_result {
162 return if tool_result.is_error {
163 ToolUseStatus::Error(tool_result.content.clone().into())
164 } else {
165 ToolUseStatus::Finished(tool_result.content.clone().into())
166 };
167 }
168
169 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
170 match pending_tool_use.status {
171 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
172 PendingToolUseStatus::NeedsConfirmation { .. } => {
173 ToolUseStatus::NeedsConfirmation
174 }
175 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
176 PendingToolUseStatus::Error(ref err) => {
177 ToolUseStatus::Error(err.clone().into())
178 }
179 }
180 } else {
181 ToolUseStatus::Pending
182 }
183 })();
184
185 let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
186 {
187 (tool.icon(), tool.needs_confirmation())
188 } else {
189 (IconName::Cog, false)
190 };
191
192 tool_uses.push(ToolUse {
193 id: tool_use.id.clone(),
194 name: tool_use.name.clone().into(),
195 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
196 input: tool_use.input.clone(),
197 status,
198 icon,
199 needs_confirmation,
200 })
201 }
202
203 tool_uses
204 }
205
206 pub fn tool_ui_label(
207 &self,
208 tool_name: &str,
209 input: &serde_json::Value,
210 cx: &App,
211 ) -> SharedString {
212 if let Some(tool) = self.tools.tool(tool_name, cx) {
213 tool.ui_text(input).into()
214 } else {
215 format!("Unknown tool {tool_name:?}").into()
216 }
217 }
218
219 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
220 let empty = Vec::new();
221
222 self.tool_uses_by_user_message
223 .get(&message_id)
224 .unwrap_or(&empty)
225 .iter()
226 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
227 .collect()
228 }
229
230 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
231 self.tool_uses_by_user_message
232 .get(&message_id)
233 .map_or(false, |results| !results.is_empty())
234 }
235
236 pub fn tool_result(
237 &self,
238 tool_use_id: &LanguageModelToolUseId,
239 ) -> Option<&LanguageModelToolResult> {
240 self.tool_results.get(tool_use_id)
241 }
242
243 pub fn request_tool_use(
244 &mut self,
245 assistant_message_id: MessageId,
246 tool_use: LanguageModelToolUse,
247 cx: &App,
248 ) {
249 self.tool_uses_by_assistant_message
250 .entry(assistant_message_id)
251 .or_default()
252 .push(tool_use.clone());
253
254 // The tool use is being requested by the Assistant, so we want to
255 // attach the tool results to the next user message.
256 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
257 self.tool_uses_by_user_message
258 .entry(next_user_message_id)
259 .or_default()
260 .push(tool_use.id.clone());
261
262 self.pending_tool_uses_by_id.insert(
263 tool_use.id.clone(),
264 PendingToolUse {
265 assistant_message_id,
266 id: tool_use.id,
267 name: tool_use.name.clone(),
268 ui_text: self
269 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
270 .into(),
271 input: tool_use.input,
272 status: PendingToolUseStatus::Idle,
273 },
274 );
275 }
276
277 pub fn run_pending_tool(
278 &mut self,
279 tool_use_id: LanguageModelToolUseId,
280 ui_text: SharedString,
281 task: Task<()>,
282 ) {
283 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
284 tool_use.ui_text = ui_text.into();
285 tool_use.status = PendingToolUseStatus::Running {
286 _task: task.shared(),
287 };
288 }
289 }
290
291 pub fn confirm_tool_use(
292 &mut self,
293 tool_use_id: LanguageModelToolUseId,
294 ui_text: impl Into<Arc<str>>,
295 input: serde_json::Value,
296 messages: Arc<Vec<LanguageModelRequestMessage>>,
297 tool: Arc<dyn Tool>,
298 ) {
299 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
300 let ui_text = ui_text.into();
301 tool_use.ui_text = ui_text.clone();
302 let confirmation = Confirmation {
303 tool_use_id,
304 input,
305 messages,
306 tool,
307 ui_text,
308 };
309 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
310 }
311 }
312
313 pub fn insert_tool_output(
314 &mut self,
315 tool_use_id: LanguageModelToolUseId,
316 output: Result<String>,
317 ) -> Option<PendingToolUse> {
318 match output {
319 Ok(tool_result) => {
320 self.tool_results.insert(
321 tool_use_id.clone(),
322 LanguageModelToolResult {
323 tool_use_id: tool_use_id.clone(),
324 content: tool_result.into(),
325 is_error: false,
326 },
327 );
328 self.pending_tool_uses_by_id.remove(&tool_use_id)
329 }
330 Err(err) => {
331 self.tool_results.insert(
332 tool_use_id.clone(),
333 LanguageModelToolResult {
334 tool_use_id: tool_use_id.clone(),
335 content: err.to_string().into(),
336 is_error: true,
337 },
338 );
339
340 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
341 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
342 }
343
344 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
345 }
346 }
347 }
348
349 pub fn attach_tool_uses(
350 &self,
351 message_id: MessageId,
352 request_message: &mut LanguageModelRequestMessage,
353 ) {
354 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
355 for tool_use in tool_uses {
356 if self.tool_results.contains_key(&tool_use.id) {
357 // Do not send tool uses until they are completed
358 request_message
359 .content
360 .push(MessageContent::ToolUse(tool_use.clone()));
361 } else {
362 log::debug!(
363 "skipped tool use {:?} because it is still pending",
364 tool_use
365 );
366 }
367 }
368 }
369 }
370
371 pub fn attach_tool_results(
372 &self,
373 message_id: MessageId,
374 request_message: &mut LanguageModelRequestMessage,
375 ) {
376 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
377 for tool_use_id in tool_uses {
378 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
379 request_message.content.push(MessageContent::ToolResult(
380 LanguageModelToolResult {
381 tool_use_id: tool_use_id.clone(),
382 is_error: tool_result.is_error,
383 content: if tool_result.content.is_empty() {
384 // Surprisingly, the API fails if we return an empty string here.
385 // It thinks we are sending a tool use without a tool result.
386 "<Tool returned an empty string>".into()
387 } else {
388 tool_result.content.clone()
389 },
390 },
391 ));
392 }
393 }
394 }
395 }
396}
397
/// A tool use that has been requested and is tracked in
/// `ToolUseState::pending_tool_uses_by_id`.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Display label; updated when the tool use starts running or awaits confirmation.
    pub ui_text: Arc<str>,
    /// The JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Where the tool use currently is in its lifecycle.
    pub status: PendingToolUseStatus,
}
409
/// The arguments passed to `ToolUseState::confirm_tool_use`, stored on a pending
/// tool use while it is in the `NeedsConfirmation` state.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// The JSON input supplied for the tool.
    pub input: serde_json::Value,
    /// Display label; the same text is stored on the pending tool use.
    pub ui_text: Arc<str>,
    /// Request messages supplied alongside the confirmation.
    // NOTE(review): presumably the conversation to run the tool against once
    // confirmed — verify against the caller.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool awaiting confirmation.
    pub tool: Arc<dyn Tool>,
}
418
/// Lifecycle state of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Initial state assigned by `ToolUseState::request_tool_use`.
    Idle,
    /// Set by `ToolUseState::confirm_tool_use`; carries the stored confirmation data.
    NeedsConfirmation(Arc<Confirmation>),
    /// Set by `ToolUseState::run_pending_tool`.
    // NOTE(review): `_task` is stored but never read in this file; presumably it
    // is held only to keep the task alive — confirm.
    Running { _task: Shared<Task<()>> },
    /// Set by `ToolUseState::insert_tool_output` when the tool fails; carries the
    /// error text.
    Error(#[allow(unused)] Arc<str>),
}
426
427impl PendingToolUseStatus {
428 pub fn is_idle(&self) -> bool {
429 matches!(self, PendingToolUseStatus::Idle)
430 }
431
432 pub fn is_error(&self) -> bool {
433 matches!(self, PendingToolUseStatus::Error(_))
434 }
435
436 pub fn needs_confirmation(&self) -> bool {
437 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
438 }
439}