1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashMap;
5use futures::future::Shared;
6use futures::FutureExt as _;
7use gpui::{SharedString, Task};
8use language_model::{
9 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
10 LanguageModelToolUseId, MessageContent, Role,
11};
12
13use crate::thread::MessageId;
14use crate::thread_store::SavedMessage;
15
16#[derive(Debug)]
17pub struct ToolUse {
18 pub id: LanguageModelToolUseId,
19 pub name: SharedString,
20 pub status: ToolUseStatus,
21 pub input: serde_json::Value,
22}
23
24#[derive(Debug, Clone)]
25pub enum ToolUseStatus {
26 Pending,
27 Running,
28 Finished(SharedString),
29 Error(SharedString),
30}
31
32pub struct ToolUseState {
33 tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
34 tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
35 tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
36 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
37}
38
39impl ToolUseState {
40 pub fn new() -> Self {
41 Self {
42 tool_uses_by_assistant_message: HashMap::default(),
43 tool_uses_by_user_message: HashMap::default(),
44 tool_results: HashMap::default(),
45 pending_tool_uses_by_id: HashMap::default(),
46 }
47 }
48
49 pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
50 let mut this = Self::new();
51
52 for message in messages {
53 match message.role {
54 Role::Assistant => {
55 if !message.tool_uses.is_empty() {
56 this.tool_uses_by_assistant_message.insert(
57 message.id,
58 message
59 .tool_uses
60 .iter()
61 .map(|tool_use| LanguageModelToolUse {
62 id: tool_use.id.clone(),
63 name: tool_use.name.clone().into(),
64 input: tool_use.input.clone(),
65 })
66 .collect(),
67 );
68 }
69 }
70 Role::User => {
71 if !message.tool_results.is_empty() {
72 let tool_uses_by_user_message = this
73 .tool_uses_by_user_message
74 .entry(message.id)
75 .or_default();
76
77 for tool_result in &message.tool_results {
78 let tool_use_id = tool_result.tool_use_id.clone();
79
80 tool_uses_by_user_message.push(tool_use_id.clone());
81 this.tool_results.insert(
82 tool_use_id.clone(),
83 LanguageModelToolResult {
84 tool_use_id,
85 is_error: tool_result.is_error,
86 content: tool_result.content.clone(),
87 },
88 );
89 }
90 }
91 }
92 Role::System => {}
93 }
94 }
95
96 this
97 }
98
99 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
100 self.pending_tool_uses_by_id.values().collect()
101 }
102
103 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
104 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
105 return Vec::new();
106 };
107
108 let mut tool_uses = Vec::new();
109
110 for tool_use in tool_uses_for_message.iter() {
111 let tool_result = self.tool_results.get(&tool_use.id);
112
113 let status = (|| {
114 if let Some(tool_result) = tool_result {
115 return if tool_result.is_error {
116 ToolUseStatus::Error(tool_result.content.clone().into())
117 } else {
118 ToolUseStatus::Finished(tool_result.content.clone().into())
119 };
120 }
121
122 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
123 return match pending_tool_use.status {
124 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
125 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
126 PendingToolUseStatus::Error(ref err) => {
127 ToolUseStatus::Error(err.clone().into())
128 }
129 };
130 }
131
132 ToolUseStatus::Pending
133 })();
134
135 tool_uses.push(ToolUse {
136 id: tool_use.id.clone(),
137 name: tool_use.name.clone().into(),
138 input: tool_use.input.clone(),
139 status,
140 })
141 }
142
143 tool_uses
144 }
145
146 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
147 let empty = Vec::new();
148
149 self.tool_uses_by_user_message
150 .get(&message_id)
151 .unwrap_or(&empty)
152 .iter()
153 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
154 .collect()
155 }
156
157 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
158 self.tool_uses_by_user_message
159 .get(&message_id)
160 .map_or(false, |results| !results.is_empty())
161 }
162
163 pub fn request_tool_use(
164 &mut self,
165 assistant_message_id: MessageId,
166 tool_use: LanguageModelToolUse,
167 ) {
168 self.tool_uses_by_assistant_message
169 .entry(assistant_message_id)
170 .or_default()
171 .push(tool_use.clone());
172
173 // The tool use is being requested by the Assistant, so we want to
174 // attach the tool results to the next user message.
175 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
176 self.tool_uses_by_user_message
177 .entry(next_user_message_id)
178 .or_default()
179 .push(tool_use.id.clone());
180
181 self.pending_tool_uses_by_id.insert(
182 tool_use.id.clone(),
183 PendingToolUse {
184 assistant_message_id,
185 id: tool_use.id,
186 name: tool_use.name,
187 input: tool_use.input,
188 status: PendingToolUseStatus::Idle,
189 },
190 );
191 }
192
193 pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
194 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
195 tool_use.status = PendingToolUseStatus::Running {
196 _task: task.shared(),
197 };
198 }
199 }
200
201 pub fn insert_tool_output(
202 &mut self,
203 tool_use_id: LanguageModelToolUseId,
204 output: Result<String>,
205 ) {
206 match output {
207 Ok(output) => {
208 self.tool_results.insert(
209 tool_use_id.clone(),
210 LanguageModelToolResult {
211 tool_use_id: tool_use_id.clone(),
212 content: output.into(),
213 is_error: false,
214 },
215 );
216 self.pending_tool_uses_by_id.remove(&tool_use_id);
217 }
218 Err(err) => {
219 self.tool_results.insert(
220 tool_use_id.clone(),
221 LanguageModelToolResult {
222 tool_use_id: tool_use_id.clone(),
223 content: err.to_string().into(),
224 is_error: true,
225 },
226 );
227
228 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
229 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
230 }
231 }
232 }
233 }
234
235 pub fn attach_tool_uses(
236 &self,
237 message_id: MessageId,
238 request_message: &mut LanguageModelRequestMessage,
239 ) {
240 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
241 for tool_use in tool_uses {
242 request_message
243 .content
244 .push(MessageContent::ToolUse(tool_use.clone()));
245 }
246 }
247 }
248
249 pub fn attach_tool_results(
250 &self,
251 message_id: MessageId,
252 request_message: &mut LanguageModelRequestMessage,
253 ) {
254 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
255 for tool_use_id in tool_uses {
256 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
257 request_message
258 .content
259 .push(MessageContent::ToolResult(tool_result.clone()));
260 }
261 }
262 }
263 }
264}
265
266#[derive(Debug, Clone)]
267pub struct PendingToolUse {
268 pub id: LanguageModelToolUseId,
269 /// The ID of the Assistant message in which the tool use was requested.
270 pub assistant_message_id: MessageId,
271 pub name: Arc<str>,
272 pub input: serde_json::Value,
273 pub status: PendingToolUseStatus,
274}
275
276#[derive(Debug, Clone)]
277pub enum PendingToolUseStatus {
278 Idle,
279 Running { _task: Shared<Task<()>> },
280 Error(#[allow(unused)] Arc<str>),
281}
282
283impl PendingToolUseStatus {
284 pub fn is_idle(&self) -> bool {
285 matches!(self, PendingToolUseStatus::Idle)
286 }
287
288 pub fn is_error(&self) -> bool {
289 matches!(self, PendingToolUseStatus::Error(_))
290 }
291}