1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::future::Shared;
7use futures::FutureExt as _;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13
14use crate::thread::MessageId;
15use crate::thread_store::SerializedMessage;
16
17#[derive(Debug)]
18pub struct ToolUse {
19 pub id: LanguageModelToolUseId,
20 pub name: SharedString,
21 pub ui_text: SharedString,
22 pub status: ToolUseStatus,
23 pub input: serde_json::Value,
24}
25
26#[derive(Debug, Clone)]
27pub enum ToolUseStatus {
28 NeedsConfirmation,
29 Pending,
30 Running,
31 Finished(SharedString),
32 Error(SharedString),
33}
34
/// Bookkeeping for tool uses within a thread: which assistant message requested
/// each tool use, which user message its result is attached to, the recorded
/// results, and the tool uses that are still pending.
pub struct ToolUseState {
    /// Tool working set used to look up tools by name (e.g. for UI labels).
    tools: Arc<ToolWorkingSet>,
    /// Tool uses keyed by the ID of the assistant message that requested them.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Tool use IDs whose results are attached to the given user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful or error), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have not been removed from the pending set yet, in any
    /// of the states of `PendingToolUseStatus` (including errored ones).
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
42
43impl ToolUseState {
44 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
45 Self {
46 tools,
47 tool_uses_by_assistant_message: HashMap::default(),
48 tool_uses_by_user_message: HashMap::default(),
49 tool_results: HashMap::default(),
50 pending_tool_uses_by_id: HashMap::default(),
51 }
52 }
53
54 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
55 ///
56 /// Accepts a function to filter the tools that should be used to populate the state.
57 pub fn from_serialized_messages(
58 tools: Arc<ToolWorkingSet>,
59 messages: &[SerializedMessage],
60 mut filter_by_tool_name: impl FnMut(&str) -> bool,
61 ) -> Self {
62 let mut this = Self::new(tools);
63 let mut tool_names_by_id = HashMap::default();
64
65 for message in messages {
66 match message.role {
67 Role::Assistant => {
68 if !message.tool_uses.is_empty() {
69 let tool_uses = message
70 .tool_uses
71 .iter()
72 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
73 .map(|tool_use| LanguageModelToolUse {
74 id: tool_use.id.clone(),
75 name: tool_use.name.clone().into(),
76 input: tool_use.input.clone(),
77 })
78 .collect::<Vec<_>>();
79
80 tool_names_by_id.extend(
81 tool_uses
82 .iter()
83 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
84 );
85
86 this.tool_uses_by_assistant_message
87 .insert(message.id, tool_uses);
88 }
89 }
90 Role::User => {
91 if !message.tool_results.is_empty() {
92 let tool_uses_by_user_message = this
93 .tool_uses_by_user_message
94 .entry(message.id)
95 .or_default();
96
97 for tool_result in &message.tool_results {
98 let tool_use_id = tool_result.tool_use_id.clone();
99 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
100 log::warn!("no tool name found for tool use: {tool_use_id:?}");
101 continue;
102 };
103
104 if !(filter_by_tool_name)(tool_use.as_ref()) {
105 continue;
106 }
107
108 tool_uses_by_user_message.push(tool_use_id.clone());
109 this.tool_results.insert(
110 tool_use_id.clone(),
111 LanguageModelToolResult {
112 tool_use_id,
113 is_error: tool_result.is_error,
114 content: tool_result.content.clone(),
115 },
116 );
117 }
118 }
119 }
120 Role::System => {}
121 }
122 }
123
124 this
125 }
126
127 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
128 let mut pending_tools = Vec::new();
129 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
130 self.tool_results.insert(
131 tool_use_id.clone(),
132 LanguageModelToolResult {
133 tool_use_id,
134 content: "Tool canceled by user".into(),
135 is_error: true,
136 },
137 );
138 pending_tools.push(tool_use.clone());
139 }
140 pending_tools
141 }
142
143 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
144 self.pending_tool_uses_by_id.values().collect()
145 }
146
147 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
148 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
149 return Vec::new();
150 };
151
152 let mut tool_uses = Vec::new();
153
154 for tool_use in tool_uses_for_message.iter() {
155 let tool_result = self.tool_results.get(&tool_use.id);
156
157 let status = (|| {
158 if let Some(tool_result) = tool_result {
159 return if tool_result.is_error {
160 ToolUseStatus::Error(tool_result.content.clone().into())
161 } else {
162 ToolUseStatus::Finished(tool_result.content.clone().into())
163 };
164 }
165
166 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
167 match pending_tool_use.status {
168 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
169 PendingToolUseStatus::NeedsConfirmation { .. } => {
170 ToolUseStatus::NeedsConfirmation
171 }
172 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
173 PendingToolUseStatus::Error(ref err) => {
174 ToolUseStatus::Error(err.clone().into())
175 }
176 }
177 } else {
178 ToolUseStatus::Pending
179 }
180 })();
181
182 tool_uses.push(ToolUse {
183 id: tool_use.id.clone(),
184 name: tool_use.name.clone().into(),
185 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
186 input: tool_use.input.clone(),
187 status,
188 })
189 }
190
191 tool_uses
192 }
193
194 pub fn tool_ui_label(
195 &self,
196 tool_name: &str,
197 input: &serde_json::Value,
198 cx: &App,
199 ) -> SharedString {
200 if let Some(tool) = self.tools.tool(tool_name, cx) {
201 tool.ui_text(input).into()
202 } else {
203 "Unknown tool".into()
204 }
205 }
206
207 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
208 let empty = Vec::new();
209
210 self.tool_uses_by_user_message
211 .get(&message_id)
212 .unwrap_or(&empty)
213 .iter()
214 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
215 .collect()
216 }
217
218 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
219 self.tool_uses_by_user_message
220 .get(&message_id)
221 .map_or(false, |results| !results.is_empty())
222 }
223
224 pub fn tool_result(
225 &self,
226 tool_use_id: &LanguageModelToolUseId,
227 ) -> Option<&LanguageModelToolResult> {
228 self.tool_results.get(tool_use_id)
229 }
230
231 pub fn request_tool_use(
232 &mut self,
233 assistant_message_id: MessageId,
234 tool_use: LanguageModelToolUse,
235 cx: &App,
236 ) {
237 self.tool_uses_by_assistant_message
238 .entry(assistant_message_id)
239 .or_default()
240 .push(tool_use.clone());
241
242 // The tool use is being requested by the Assistant, so we want to
243 // attach the tool results to the next user message.
244 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
245 self.tool_uses_by_user_message
246 .entry(next_user_message_id)
247 .or_default()
248 .push(tool_use.id.clone());
249
250 self.pending_tool_uses_by_id.insert(
251 tool_use.id.clone(),
252 PendingToolUse {
253 assistant_message_id,
254 id: tool_use.id,
255 name: tool_use.name.clone(),
256 ui_text: self
257 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
258 .into(),
259 input: tool_use.input,
260 status: PendingToolUseStatus::Idle,
261 },
262 );
263 }
264
265 pub fn run_pending_tool(
266 &mut self,
267 tool_use_id: LanguageModelToolUseId,
268 ui_text: SharedString,
269 task: Task<()>,
270 ) {
271 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
272 tool_use.ui_text = ui_text.into();
273 tool_use.status = PendingToolUseStatus::Running {
274 _task: task.shared(),
275 };
276 }
277 }
278
279 pub fn confirm_tool_use(
280 &mut self,
281 tool_use_id: LanguageModelToolUseId,
282 ui_text: impl Into<Arc<str>>,
283 input: serde_json::Value,
284 messages: Arc<Vec<LanguageModelRequestMessage>>,
285 tool: Arc<dyn Tool>,
286 ) {
287 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
288 let ui_text = ui_text.into();
289 tool_use.ui_text = ui_text.clone();
290 let confirmation = Confirmation {
291 tool_use_id,
292 input,
293 messages,
294 tool,
295 ui_text,
296 };
297 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
298 }
299 }
300
301 pub fn insert_tool_output(
302 &mut self,
303 tool_use_id: LanguageModelToolUseId,
304 output: Result<String>,
305 ) -> Option<PendingToolUse> {
306 match output {
307 Ok(tool_result) => {
308 self.tool_results.insert(
309 tool_use_id.clone(),
310 LanguageModelToolResult {
311 tool_use_id: tool_use_id.clone(),
312 content: tool_result.into(),
313 is_error: false,
314 },
315 );
316 self.pending_tool_uses_by_id.remove(&tool_use_id)
317 }
318 Err(err) => {
319 self.tool_results.insert(
320 tool_use_id.clone(),
321 LanguageModelToolResult {
322 tool_use_id: tool_use_id.clone(),
323 content: err.to_string().into(),
324 is_error: true,
325 },
326 );
327
328 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
329 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
330 }
331
332 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
333 }
334 }
335 }
336
337 pub fn attach_tool_uses(
338 &self,
339 message_id: MessageId,
340 request_message: &mut LanguageModelRequestMessage,
341 ) {
342 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
343 for tool_use in tool_uses {
344 if self.tool_results.contains_key(&tool_use.id) {
345 // Do not send tool uses until they are completed
346 request_message
347 .content
348 .push(MessageContent::ToolUse(tool_use.clone()));
349 } else {
350 log::debug!(
351 "skipped tool use {:?} because it is still pending",
352 tool_use
353 );
354 }
355 }
356 }
357 }
358
359 pub fn attach_tool_results(
360 &self,
361 message_id: MessageId,
362 request_message: &mut LanguageModelRequestMessage,
363 ) {
364 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
365 for tool_use_id in tool_uses {
366 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
367 request_message.content.push(MessageContent::ToolResult(
368 LanguageModelToolResult {
369 tool_use_id: tool_use_id.clone(),
370 is_error: tool_result.is_error,
371 content: if tool_result.content.is_empty() {
372 // Surprisingly, the API fails if we return an empty string here.
373 // It thinks we are sending a tool use without a tool result.
374 "<Tool returned an empty string>".into()
375 } else {
376 tool_result.content.clone()
377 },
378 },
379 ));
380 }
381 }
382 }
383 }
384}
385
/// A tool use that has been requested but not yet removed from the pending set.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read anywhere in this file, hence the lint allowance;
    // presumably kept for callers or future use — confirm before removing.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Human-readable label shown for this tool use; updated when the tool use
    /// is run or moved into confirmation.
    pub ui_text: Arc<str>,
    /// JSON input supplied for the tool.
    pub input: serde_json::Value,
    /// Current lifecycle state of the tool use.
    pub status: PendingToolUseStatus,
}
397
/// Everything captured when a pending tool use is moved into the
/// `NeedsConfirmation` state.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// ID of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input to pass to the tool.
    pub input: serde_json::Value,
    /// Human-readable label for the tool use.
    pub ui_text: Arc<str>,
    /// Request messages captured alongside the tool use.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to run.
    pub tool: Arc<dyn Tool>,
}
406
407#[derive(Debug, Clone)]
408pub enum PendingToolUseStatus {
409 Idle,
410 NeedsConfirmation(Arc<Confirmation>),
411 Running { _task: Shared<Task<()>> },
412 Error(#[allow(unused)] Arc<str>),
413}
414
415impl PendingToolUseStatus {
416 pub fn is_idle(&self) -> bool {
417 matches!(self, PendingToolUseStatus::Idle)
418 }
419
420 pub fn is_error(&self) -> bool {
421 matches!(self, PendingToolUseStatus::Error(_))
422 }
423
424 pub fn needs_confirmation(&self) -> bool {
425 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
426 }
427}