1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::future::Shared;
7use futures::FutureExt as _;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13use ui::IconName;
14
15use crate::thread::MessageId;
16use crate::thread_store::SerializedMessage;
17
18#[derive(Debug)]
19pub struct ToolUse {
20 pub id: LanguageModelToolUseId,
21 pub name: SharedString,
22 pub ui_text: SharedString,
23 pub status: ToolUseStatus,
24 pub input: serde_json::Value,
25 pub icon: ui::IconName,
26}
27
28#[derive(Debug, Clone)]
29pub enum ToolUseStatus {
30 NeedsConfirmation,
31 Pending,
32 Running,
33 Finished(SharedString),
34 Error(SharedString),
35}
36
37pub struct ToolUseState {
38 tools: Arc<ToolWorkingSet>,
39 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
40 tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
41 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
42 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
43}
44
45impl ToolUseState {
46 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
47 Self {
48 tools,
49 tool_uses_by_assistant_message: HashMap::default(),
50 tool_uses_by_user_message: HashMap::default(),
51 tool_results: HashMap::default(),
52 pending_tool_uses_by_id: HashMap::default(),
53 }
54 }
55
56 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
57 ///
58 /// Accepts a function to filter the tools that should be used to populate the state.
59 pub fn from_serialized_messages(
60 tools: Arc<ToolWorkingSet>,
61 messages: &[SerializedMessage],
62 mut filter_by_tool_name: impl FnMut(&str) -> bool,
63 ) -> Self {
64 let mut this = Self::new(tools);
65 let mut tool_names_by_id = HashMap::default();
66
67 for message in messages {
68 match message.role {
69 Role::Assistant => {
70 if !message.tool_uses.is_empty() {
71 let tool_uses = message
72 .tool_uses
73 .iter()
74 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
75 .map(|tool_use| LanguageModelToolUse {
76 id: tool_use.id.clone(),
77 name: tool_use.name.clone().into(),
78 input: tool_use.input.clone(),
79 })
80 .collect::<Vec<_>>();
81
82 tool_names_by_id.extend(
83 tool_uses
84 .iter()
85 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
86 );
87
88 this.tool_uses_by_assistant_message
89 .insert(message.id, tool_uses);
90 }
91 }
92 Role::User => {
93 if !message.tool_results.is_empty() {
94 let tool_uses_by_user_message = this
95 .tool_uses_by_user_message
96 .entry(message.id)
97 .or_default();
98
99 for tool_result in &message.tool_results {
100 let tool_use_id = tool_result.tool_use_id.clone();
101 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
102 log::warn!("no tool name found for tool use: {tool_use_id:?}");
103 continue;
104 };
105
106 if !(filter_by_tool_name)(tool_use.as_ref()) {
107 continue;
108 }
109
110 tool_uses_by_user_message.push(tool_use_id.clone());
111 this.tool_results.insert(
112 tool_use_id.clone(),
113 LanguageModelToolResult {
114 tool_use_id,
115 is_error: tool_result.is_error,
116 content: tool_result.content.clone(),
117 },
118 );
119 }
120 }
121 }
122 Role::System => {}
123 }
124 }
125
126 this
127 }
128
129 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130 let mut pending_tools = Vec::new();
131 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
132 self.tool_results.insert(
133 tool_use_id.clone(),
134 LanguageModelToolResult {
135 tool_use_id,
136 content: "Tool canceled by user".into(),
137 is_error: true,
138 },
139 );
140 pending_tools.push(tool_use.clone());
141 }
142 pending_tools
143 }
144
145 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
146 self.pending_tool_uses_by_id.values().collect()
147 }
148
149 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
150 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
151 return Vec::new();
152 };
153
154 let mut tool_uses = Vec::new();
155
156 for tool_use in tool_uses_for_message.iter() {
157 let tool_result = self.tool_results.get(&tool_use.id);
158
159 let status = (|| {
160 if let Some(tool_result) = tool_result {
161 return if tool_result.is_error {
162 ToolUseStatus::Error(tool_result.content.clone().into())
163 } else {
164 ToolUseStatus::Finished(tool_result.content.clone().into())
165 };
166 }
167
168 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
169 match pending_tool_use.status {
170 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
171 PendingToolUseStatus::NeedsConfirmation { .. } => {
172 ToolUseStatus::NeedsConfirmation
173 }
174 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
175 PendingToolUseStatus::Error(ref err) => {
176 ToolUseStatus::Error(err.clone().into())
177 }
178 }
179 } else {
180 ToolUseStatus::Pending
181 }
182 })();
183
184 let icon = if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
185 tool.icon()
186 } else {
187 IconName::Cog
188 };
189
190 tool_uses.push(ToolUse {
191 id: tool_use.id.clone(),
192 name: tool_use.name.clone().into(),
193 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
194 input: tool_use.input.clone(),
195 status,
196 icon,
197 })
198 }
199
200 tool_uses
201 }
202
203 pub fn tool_ui_label(
204 &self,
205 tool_name: &str,
206 input: &serde_json::Value,
207 cx: &App,
208 ) -> SharedString {
209 if let Some(tool) = self.tools.tool(tool_name, cx) {
210 tool.ui_text(input).into()
211 } else {
212 format!("Unknown tool {tool_name:?}").into()
213 }
214 }
215
216 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
217 let empty = Vec::new();
218
219 self.tool_uses_by_user_message
220 .get(&message_id)
221 .unwrap_or(&empty)
222 .iter()
223 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
224 .collect()
225 }
226
227 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
228 self.tool_uses_by_user_message
229 .get(&message_id)
230 .map_or(false, |results| !results.is_empty())
231 }
232
233 pub fn tool_result(
234 &self,
235 tool_use_id: &LanguageModelToolUseId,
236 ) -> Option<&LanguageModelToolResult> {
237 self.tool_results.get(tool_use_id)
238 }
239
240 pub fn request_tool_use(
241 &mut self,
242 assistant_message_id: MessageId,
243 tool_use: LanguageModelToolUse,
244 cx: &App,
245 ) {
246 self.tool_uses_by_assistant_message
247 .entry(assistant_message_id)
248 .or_default()
249 .push(tool_use.clone());
250
251 // The tool use is being requested by the Assistant, so we want to
252 // attach the tool results to the next user message.
253 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
254 self.tool_uses_by_user_message
255 .entry(next_user_message_id)
256 .or_default()
257 .push(tool_use.id.clone());
258
259 self.pending_tool_uses_by_id.insert(
260 tool_use.id.clone(),
261 PendingToolUse {
262 assistant_message_id,
263 id: tool_use.id,
264 name: tool_use.name.clone(),
265 ui_text: self
266 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
267 .into(),
268 input: tool_use.input,
269 status: PendingToolUseStatus::Idle,
270 },
271 );
272 }
273
274 pub fn run_pending_tool(
275 &mut self,
276 tool_use_id: LanguageModelToolUseId,
277 ui_text: SharedString,
278 task: Task<()>,
279 ) {
280 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
281 tool_use.ui_text = ui_text.into();
282 tool_use.status = PendingToolUseStatus::Running {
283 _task: task.shared(),
284 };
285 }
286 }
287
288 pub fn confirm_tool_use(
289 &mut self,
290 tool_use_id: LanguageModelToolUseId,
291 ui_text: impl Into<Arc<str>>,
292 input: serde_json::Value,
293 messages: Arc<Vec<LanguageModelRequestMessage>>,
294 tool: Arc<dyn Tool>,
295 ) {
296 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
297 let ui_text = ui_text.into();
298 tool_use.ui_text = ui_text.clone();
299 let confirmation = Confirmation {
300 tool_use_id,
301 input,
302 messages,
303 tool,
304 ui_text,
305 };
306 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
307 }
308 }
309
310 pub fn insert_tool_output(
311 &mut self,
312 tool_use_id: LanguageModelToolUseId,
313 output: Result<String>,
314 ) -> Option<PendingToolUse> {
315 match output {
316 Ok(tool_result) => {
317 self.tool_results.insert(
318 tool_use_id.clone(),
319 LanguageModelToolResult {
320 tool_use_id: tool_use_id.clone(),
321 content: tool_result.into(),
322 is_error: false,
323 },
324 );
325 self.pending_tool_uses_by_id.remove(&tool_use_id)
326 }
327 Err(err) => {
328 self.tool_results.insert(
329 tool_use_id.clone(),
330 LanguageModelToolResult {
331 tool_use_id: tool_use_id.clone(),
332 content: err.to_string().into(),
333 is_error: true,
334 },
335 );
336
337 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
338 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
339 }
340
341 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
342 }
343 }
344 }
345
346 pub fn attach_tool_uses(
347 &self,
348 message_id: MessageId,
349 request_message: &mut LanguageModelRequestMessage,
350 ) {
351 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
352 for tool_use in tool_uses {
353 if self.tool_results.contains_key(&tool_use.id) {
354 // Do not send tool uses until they are completed
355 request_message
356 .content
357 .push(MessageContent::ToolUse(tool_use.clone()));
358 } else {
359 log::debug!(
360 "skipped tool use {:?} because it is still pending",
361 tool_use
362 );
363 }
364 }
365 }
366 }
367
368 pub fn attach_tool_results(
369 &self,
370 message_id: MessageId,
371 request_message: &mut LanguageModelRequestMessage,
372 ) {
373 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
374 for tool_use_id in tool_uses {
375 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
376 request_message.content.push(MessageContent::ToolResult(
377 LanguageModelToolResult {
378 tool_use_id: tool_use_id.clone(),
379 is_error: tool_result.is_error,
380 content: if tool_result.content.is_empty() {
381 // Surprisingly, the API fails if we return an empty string here.
382 // It thinks we are sending a tool use without a tool result.
383 "<Tool returned an empty string>".into()
384 } else {
385 tool_result.content.clone()
386 },
387 },
388 ));
389 }
390 }
391 }
392 }
393}
394
395#[derive(Debug, Clone)]
396pub struct PendingToolUse {
397 pub id: LanguageModelToolUseId,
398 /// The ID of the Assistant message in which the tool use was requested.
399 #[allow(unused)]
400 pub assistant_message_id: MessageId,
401 pub name: Arc<str>,
402 pub ui_text: Arc<str>,
403 pub input: serde_json::Value,
404 pub status: PendingToolUseStatus,
405}
406
407#[derive(Debug, Clone)]
408pub struct Confirmation {
409 pub tool_use_id: LanguageModelToolUseId,
410 pub input: serde_json::Value,
411 pub ui_text: Arc<str>,
412 pub messages: Arc<Vec<LanguageModelRequestMessage>>,
413 pub tool: Arc<dyn Tool>,
414}
415
416#[derive(Debug, Clone)]
417pub enum PendingToolUseStatus {
418 Idle,
419 NeedsConfirmation(Arc<Confirmation>),
420 Running { _task: Shared<Task<()>> },
421 Error(#[allow(unused)] Arc<str>),
422}
423
424impl PendingToolUseStatus {
425 pub fn is_idle(&self) -> bool {
426 matches!(self, PendingToolUseStatus::Idle)
427 }
428
429 pub fn is_error(&self) -> bool {
430 matches!(self, PendingToolUseStatus::Error(_))
431 }
432
433 pub fn needs_confirmation(&self) -> bool {
434 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
435 }
436}