1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashMap;
5use futures::future::Shared;
6use futures::FutureExt as _;
7use gpui::{SharedString, Task};
8use language_model::{
9 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
10 LanguageModelToolUseId, MessageContent, Role,
11};
12
13use crate::thread::MessageId;
14use crate::thread_store::SerializedMessage;
15
/// A tool use requested by the Assistant, paired with its current status.
///
/// Instances are built by `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// The ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// The name of the tool that was requested.
    pub name: SharedString,
    /// The current status of the tool use.
    pub status: ToolUseStatus,
    /// The tool's input, copied from the recorded `LanguageModelToolUse`.
    pub input: serde_json::Value,
}
23
/// The status of a [`ToolUse`], as reported by `ToolUseState::tool_uses_for_message`.
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// No result has been recorded, and the tool use is either idle or not
    /// tracked as pending.
    Pending,
    /// The tool use is pending and its tool is currently running.
    Running,
    /// A successful result was recorded; holds the result content.
    Finished(SharedString),
    /// The tool failed; holds the error result content (or, when no result is
    /// recorded, the pending tool use's error message).
    Error(SharedString),
}
31
/// Tracks the tool uses requested by the Assistant, their results, and the
/// tool uses that are still pending.
pub struct ToolUseState {
    /// Tool uses requested in each Assistant message, in the order they were recorded.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful or error), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have been requested but have not completed successfully.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
38
39impl ToolUseState {
40 pub fn new() -> Self {
41 Self {
42 tool_uses_by_assistant_message: HashMap::default(),
43 tool_uses_by_user_message: HashMap::default(),
44 tool_results: HashMap::default(),
45 pending_tool_uses_by_id: HashMap::default(),
46 }
47 }
48
49 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
50 ///
51 /// Accepts a function to filter the tools that should be used to populate the state.
52 pub fn from_serialized_messages(
53 messages: &[SerializedMessage],
54 mut filter_by_tool_name: impl FnMut(&str) -> bool,
55 ) -> Self {
56 let mut this = Self::new();
57 let mut tool_names_by_id = HashMap::default();
58
59 for message in messages {
60 match message.role {
61 Role::Assistant => {
62 if !message.tool_uses.is_empty() {
63 let tool_uses = message
64 .tool_uses
65 .iter()
66 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
67 .map(|tool_use| LanguageModelToolUse {
68 id: tool_use.id.clone(),
69 name: tool_use.name.clone().into(),
70 input: tool_use.input.clone(),
71 })
72 .collect::<Vec<_>>();
73
74 tool_names_by_id.extend(
75 tool_uses
76 .iter()
77 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
78 );
79
80 this.tool_uses_by_assistant_message
81 .insert(message.id, tool_uses);
82 }
83 }
84 Role::User => {
85 if !message.tool_results.is_empty() {
86 let tool_uses_by_user_message = this
87 .tool_uses_by_user_message
88 .entry(message.id)
89 .or_default();
90
91 for tool_result in &message.tool_results {
92 let tool_use_id = tool_result.tool_use_id.clone();
93 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
94 log::warn!("no tool name found for tool use: {tool_use_id:?}");
95 continue;
96 };
97
98 if !(filter_by_tool_name)(tool_use.as_ref()) {
99 continue;
100 }
101
102 tool_uses_by_user_message.push(tool_use_id.clone());
103 this.tool_results.insert(
104 tool_use_id.clone(),
105 LanguageModelToolResult {
106 tool_use_id,
107 is_error: tool_result.is_error,
108 content: tool_result.content.clone(),
109 },
110 );
111 }
112 }
113 }
114 Role::System => {}
115 }
116 }
117
118 this
119 }
120
121 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
122 self.pending_tool_uses_by_id.values().collect()
123 }
124
125 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
126 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
127 return Vec::new();
128 };
129
130 let mut tool_uses = Vec::new();
131
132 for tool_use in tool_uses_for_message.iter() {
133 let tool_result = self.tool_results.get(&tool_use.id);
134
135 let status = (|| {
136 if let Some(tool_result) = tool_result {
137 return if tool_result.is_error {
138 ToolUseStatus::Error(tool_result.content.clone().into())
139 } else {
140 ToolUseStatus::Finished(tool_result.content.clone().into())
141 };
142 }
143
144 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
145 return match pending_tool_use.status {
146 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
147 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
148 PendingToolUseStatus::Error(ref err) => {
149 ToolUseStatus::Error(err.clone().into())
150 }
151 };
152 }
153
154 ToolUseStatus::Pending
155 })();
156
157 tool_uses.push(ToolUse {
158 id: tool_use.id.clone(),
159 name: tool_use.name.clone().into(),
160 input: tool_use.input.clone(),
161 status,
162 })
163 }
164
165 tool_uses
166 }
167
168 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
169 let empty = Vec::new();
170
171 self.tool_uses_by_user_message
172 .get(&message_id)
173 .unwrap_or(&empty)
174 .iter()
175 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
176 .collect()
177 }
178
179 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
180 self.tool_uses_by_user_message
181 .get(&message_id)
182 .map_or(false, |results| !results.is_empty())
183 }
184
185 pub fn request_tool_use(
186 &mut self,
187 assistant_message_id: MessageId,
188 tool_use: LanguageModelToolUse,
189 ) {
190 self.tool_uses_by_assistant_message
191 .entry(assistant_message_id)
192 .or_default()
193 .push(tool_use.clone());
194
195 // The tool use is being requested by the Assistant, so we want to
196 // attach the tool results to the next user message.
197 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
198 self.tool_uses_by_user_message
199 .entry(next_user_message_id)
200 .or_default()
201 .push(tool_use.id.clone());
202
203 self.pending_tool_uses_by_id.insert(
204 tool_use.id.clone(),
205 PendingToolUse {
206 assistant_message_id,
207 id: tool_use.id,
208 name: tool_use.name,
209 input: tool_use.input,
210 status: PendingToolUseStatus::Idle,
211 },
212 );
213 }
214
215 pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
216 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
217 tool_use.status = PendingToolUseStatus::Running {
218 _task: task.shared(),
219 };
220 }
221 }
222
223 pub fn insert_tool_output(
224 &mut self,
225 tool_use_id: LanguageModelToolUseId,
226 output: Result<String>,
227 ) -> Option<PendingToolUse> {
228 match output {
229 Ok(tool_result) => {
230 self.tool_results.insert(
231 tool_use_id.clone(),
232 LanguageModelToolResult {
233 tool_use_id: tool_use_id.clone(),
234 content: tool_result.into(),
235 is_error: false,
236 },
237 );
238 self.pending_tool_uses_by_id.remove(&tool_use_id)
239 }
240 Err(err) => {
241 self.tool_results.insert(
242 tool_use_id.clone(),
243 LanguageModelToolResult {
244 tool_use_id: tool_use_id.clone(),
245 content: err.to_string().into(),
246 is_error: true,
247 },
248 );
249
250 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
251 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
252 }
253
254 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
255 }
256 }
257 }
258
259 pub fn attach_tool_uses(
260 &self,
261 message_id: MessageId,
262 request_message: &mut LanguageModelRequestMessage,
263 ) {
264 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
265 for tool_use in tool_uses {
266 request_message
267 .content
268 .push(MessageContent::ToolUse(tool_use.clone()));
269 }
270 }
271 }
272
273 pub fn attach_tool_results(
274 &self,
275 message_id: MessageId,
276 request_message: &mut LanguageModelRequestMessage,
277 ) {
278 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
279 for tool_use_id in tool_uses {
280 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
281 request_message
282 .content
283 .push(MessageContent::ToolResult(tool_result.clone()));
284 }
285 }
286 }
287 }
288}
289
/// A tool use that has been requested but has not completed successfully.
///
/// Entries are created by `ToolUseState::request_tool_use`. They are removed
/// when the tool produces a successful output and kept, with an `Error` status,
/// when it fails.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// The ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // This field is not read anywhere in this file; presumably the `allow`
    // silences a dead-code warning. Confirm whether other modules read it.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// The name of the tool to invoke.
    pub name: Arc<str>,
    /// The tool's input, copied from the requested `LanguageModelToolUse`.
    pub input: serde_json::Value,
    /// The current status of this tool use.
    pub status: PendingToolUseStatus,
}
300
301#[derive(Debug, Clone)]
302pub enum PendingToolUseStatus {
303 Idle,
304 Running { _task: Shared<Task<()>> },
305 Error(#[allow(unused)] Arc<str>),
306}
307
308impl PendingToolUseStatus {
309 pub fn is_idle(&self) -> bool {
310 matches!(self, PendingToolUseStatus::Idle)
311 }
312
313 pub fn is_error(&self) -> bool {
314 matches!(self, PendingToolUseStatus::Error(_))
315 }
316}