1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashMap;
5use futures::future::Shared;
6use futures::FutureExt as _;
7use gpui::{SharedString, Task};
8use language_model::{
9 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
10 LanguageModelToolUseId, MessageContent,
11};
12
13use crate::thread::MessageId;
14
/// A tool use requested by an assistant message, paired with its current
/// status. Values of this type are built by `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the requested tool.
    pub name: SharedString,
    pub status: ToolUseStatus,
    /// The input supplied for the tool, as a JSON value.
    pub input: serde_json::Value,
}
22
/// Status of a tool use, as derived by `ToolUseState::tool_uses_for_message`.
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// No result has been recorded and the tool is not running (it is idle,
    /// or not tracked as pending at all).
    Pending,
    /// The tool has been started and has not produced a result yet.
    Running,
    /// The tool produced a successful result; holds the result content.
    Finished(SharedString),
    /// The tool failed; holds the error message.
    Error(SharedString),
}
30
/// Bookkeeping for tool uses: which assistant message requested each tool use,
/// which user message carries its result, the recorded results, and the tool
/// uses that are still pending.
#[derive(Default)]
pub struct ToolUseState {
    /// Tool uses requested by each assistant message, in request order.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded results (successful or failed), keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses without a successful result: idle, running, or errored.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
38
39impl ToolUseState {
40 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
41 self.pending_tool_uses_by_id.values().collect()
42 }
43
44 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
45 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
46 return Vec::new();
47 };
48
49 let mut tool_uses = Vec::new();
50
51 for tool_use in tool_uses_for_message.iter() {
52 let tool_result = self.tool_results.get(&tool_use.id);
53
54 let status = (|| {
55 if let Some(tool_result) = tool_result {
56 return if tool_result.is_error {
57 ToolUseStatus::Error(tool_result.content.clone().into())
58 } else {
59 ToolUseStatus::Finished(tool_result.content.clone().into())
60 };
61 }
62
63 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
64 return match pending_tool_use.status {
65 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
66 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
67 PendingToolUseStatus::Error(ref err) => {
68 ToolUseStatus::Error(err.clone().into())
69 }
70 };
71 }
72
73 ToolUseStatus::Pending
74 })();
75
76 tool_uses.push(ToolUse {
77 id: tool_use.id.clone(),
78 name: tool_use.name.clone().into(),
79 input: tool_use.input.clone(),
80 status,
81 })
82 }
83
84 tool_uses
85 }
86
87 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
88 self.tool_uses_by_user_message
89 .get(&message_id)
90 .map_or(false, |results| !results.is_empty())
91 }
92
93 pub fn request_tool_use(
94 &mut self,
95 assistant_message_id: MessageId,
96 tool_use: LanguageModelToolUse,
97 ) {
98 self.tool_uses_by_assistant_message
99 .entry(assistant_message_id)
100 .or_default()
101 .push(tool_use.clone());
102
103 // The tool use is being requested by the Assistant, so we want to
104 // attach the tool results to the next user message.
105 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
106 self.tool_uses_by_user_message
107 .entry(next_user_message_id)
108 .or_default()
109 .push(tool_use.id.clone());
110
111 self.pending_tool_uses_by_id.insert(
112 tool_use.id.clone(),
113 PendingToolUse {
114 assistant_message_id,
115 id: tool_use.id,
116 name: tool_use.name,
117 input: tool_use.input,
118 status: PendingToolUseStatus::Idle,
119 },
120 );
121 }
122
123 pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
124 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
125 tool_use.status = PendingToolUseStatus::Running {
126 _task: task.shared(),
127 };
128 }
129 }
130
131 pub fn insert_tool_output(
132 &mut self,
133 tool_use_id: LanguageModelToolUseId,
134 output: Result<String>,
135 ) {
136 match output {
137 Ok(output) => {
138 self.tool_results.insert(
139 tool_use_id.clone(),
140 LanguageModelToolResult {
141 tool_use_id: tool_use_id.clone(),
142 content: output.into(),
143 is_error: false,
144 },
145 );
146 self.pending_tool_uses_by_id.remove(&tool_use_id);
147 }
148 Err(err) => {
149 self.tool_results.insert(
150 tool_use_id.clone(),
151 LanguageModelToolResult {
152 tool_use_id: tool_use_id.clone(),
153 content: err.to_string().into(),
154 is_error: true,
155 },
156 );
157
158 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
159 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
160 }
161 }
162 }
163 }
164
165 pub fn attach_tool_uses(
166 &self,
167 message_id: MessageId,
168 request_message: &mut LanguageModelRequestMessage,
169 ) {
170 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
171 for tool_use in tool_uses {
172 request_message
173 .content
174 .push(MessageContent::ToolUse(tool_use.clone()));
175 }
176 }
177 }
178
179 pub fn attach_tool_results(
180 &self,
181 message_id: MessageId,
182 request_message: &mut LanguageModelRequestMessage,
183 ) {
184 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
185 for tool_use_id in tool_uses {
186 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
187 request_message
188 .content
189 .push(MessageContent::ToolResult(tool_result.clone()));
190 }
191 }
192 }
193 }
194}
195
/// A tool use that has been requested but has not completed successfully.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// The input supplied for the tool, as a JSON value.
    pub input: serde_json::Value,
    /// Where this tool use is in its lifecycle.
    pub status: PendingToolUseStatus,
}
205
206#[derive(Debug, Clone)]
207pub enum PendingToolUseStatus {
208 Idle,
209 Running { _task: Shared<Task<()>> },
210 Error(#[allow(unused)] Arc<str>),
211}
212
213impl PendingToolUseStatus {
214 pub fn is_idle(&self) -> bool {
215 matches!(self, PendingToolUseStatus::Idle)
216 }
217
218 pub fn is_error(&self) -> bool {
219 matches!(self, PendingToolUseStatus::Error(_))
220 }
221}