tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashMap;
  5use futures::future::Shared;
  6use futures::FutureExt as _;
  7use gpui::{SharedString, Task};
  8use language_model::{
  9    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 10    LanguageModelToolUseId, MessageContent, Role,
 11};
 12
 13use crate::thread::MessageId;
 14use crate::thread_store::SavedMessage;
 15
 16#[derive(Debug)]
 17pub struct ToolUse {
 18    pub id: LanguageModelToolUseId,
 19    pub name: SharedString,
 20    pub status: ToolUseStatus,
 21    pub input: serde_json::Value,
 22}
 23
/// Display status of a [`ToolUse`].
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The tool use has not started running, or no state is recorded for it.
    Pending,
    /// The tool is currently running.
    Running,
    /// The tool completed successfully; holds its output.
    Finished(SharedString),
    /// The tool failed; holds the error message.
    Error(SharedString),
}
 31
 32pub struct ToolUseState {
 33    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 34    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 35    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 36    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 37}
 38
 39impl ToolUseState {
 40    pub fn new() -> Self {
 41        Self {
 42            tool_uses_by_assistant_message: HashMap::default(),
 43            tool_uses_by_user_message: HashMap::default(),
 44            tool_results: HashMap::default(),
 45            pending_tool_uses_by_id: HashMap::default(),
 46        }
 47    }
 48
 49    pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
 50        let mut this = Self::new();
 51
 52        for message in messages {
 53            match message.role {
 54                Role::Assistant => {
 55                    if !message.tool_uses.is_empty() {
 56                        this.tool_uses_by_assistant_message.insert(
 57                            message.id,
 58                            message
 59                                .tool_uses
 60                                .iter()
 61                                .map(|tool_use| LanguageModelToolUse {
 62                                    id: tool_use.id.clone(),
 63                                    name: tool_use.name.clone().into(),
 64                                    input: tool_use.input.clone(),
 65                                })
 66                                .collect(),
 67                        );
 68                    }
 69                }
 70                Role::User => {
 71                    if !message.tool_results.is_empty() {
 72                        let tool_uses_by_user_message = this
 73                            .tool_uses_by_user_message
 74                            .entry(message.id)
 75                            .or_default();
 76
 77                        for tool_result in &message.tool_results {
 78                            let tool_use_id = tool_result.tool_use_id.clone();
 79
 80                            tool_uses_by_user_message.push(tool_use_id.clone());
 81                            this.tool_results.insert(
 82                                tool_use_id.clone(),
 83                                LanguageModelToolResult {
 84                                    tool_use_id,
 85                                    is_error: tool_result.is_error,
 86                                    content: tool_result.content.clone(),
 87                                },
 88                            );
 89                        }
 90                    }
 91                }
 92                Role::System => {}
 93            }
 94        }
 95
 96        this
 97    }
 98
 99    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
100        self.pending_tool_uses_by_id.values().collect()
101    }
102
103    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
104        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
105            return Vec::new();
106        };
107
108        let mut tool_uses = Vec::new();
109
110        for tool_use in tool_uses_for_message.iter() {
111            let tool_result = self.tool_results.get(&tool_use.id);
112
113            let status = (|| {
114                if let Some(tool_result) = tool_result {
115                    return if tool_result.is_error {
116                        ToolUseStatus::Error(tool_result.content.clone().into())
117                    } else {
118                        ToolUseStatus::Finished(tool_result.content.clone().into())
119                    };
120                }
121
122                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
123                    return match pending_tool_use.status {
124                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
125                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
126                        PendingToolUseStatus::Error(ref err) => {
127                            ToolUseStatus::Error(err.clone().into())
128                        }
129                    };
130                }
131
132                ToolUseStatus::Pending
133            })();
134
135            tool_uses.push(ToolUse {
136                id: tool_use.id.clone(),
137                name: tool_use.name.clone().into(),
138                input: tool_use.input.clone(),
139                status,
140            })
141        }
142
143        tool_uses
144    }
145
146    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
147        let empty = Vec::new();
148
149        self.tool_uses_by_user_message
150            .get(&message_id)
151            .unwrap_or(&empty)
152            .iter()
153            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
154            .collect()
155    }
156
157    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
158        self.tool_uses_by_user_message
159            .get(&message_id)
160            .map_or(false, |results| !results.is_empty())
161    }
162
163    pub fn request_tool_use(
164        &mut self,
165        assistant_message_id: MessageId,
166        tool_use: LanguageModelToolUse,
167    ) {
168        self.tool_uses_by_assistant_message
169            .entry(assistant_message_id)
170            .or_default()
171            .push(tool_use.clone());
172
173        // The tool use is being requested by the Assistant, so we want to
174        // attach the tool results to the next user message.
175        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
176        self.tool_uses_by_user_message
177            .entry(next_user_message_id)
178            .or_default()
179            .push(tool_use.id.clone());
180
181        self.pending_tool_uses_by_id.insert(
182            tool_use.id.clone(),
183            PendingToolUse {
184                assistant_message_id,
185                id: tool_use.id,
186                name: tool_use.name,
187                input: tool_use.input,
188                status: PendingToolUseStatus::Idle,
189            },
190        );
191    }
192
193    pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
194        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
195            tool_use.status = PendingToolUseStatus::Running {
196                _task: task.shared(),
197            };
198        }
199    }
200
201    pub fn insert_tool_output(
202        &mut self,
203        tool_use_id: LanguageModelToolUseId,
204        output: Result<String>,
205    ) {
206        match output {
207            Ok(output) => {
208                self.tool_results.insert(
209                    tool_use_id.clone(),
210                    LanguageModelToolResult {
211                        tool_use_id: tool_use_id.clone(),
212                        content: output.into(),
213                        is_error: false,
214                    },
215                );
216                self.pending_tool_uses_by_id.remove(&tool_use_id);
217            }
218            Err(err) => {
219                self.tool_results.insert(
220                    tool_use_id.clone(),
221                    LanguageModelToolResult {
222                        tool_use_id: tool_use_id.clone(),
223                        content: err.to_string().into(),
224                        is_error: true,
225                    },
226                );
227
228                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
229                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
230                }
231            }
232        }
233    }
234
235    pub fn attach_tool_uses(
236        &self,
237        message_id: MessageId,
238        request_message: &mut LanguageModelRequestMessage,
239    ) {
240        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
241            for tool_use in tool_uses {
242                request_message
243                    .content
244                    .push(MessageContent::ToolUse(tool_use.clone()));
245            }
246        }
247    }
248
249    pub fn attach_tool_results(
250        &self,
251        message_id: MessageId,
252        request_message: &mut LanguageModelRequestMessage,
253    ) {
254        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
255            for tool_use_id in tool_uses {
256                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
257                    request_message
258                        .content
259                        .push(MessageContent::ToolResult(tool_result.clone()));
260                }
261            }
262        }
263    }
264}
265
/// A tool use that was requested but has not yet completed successfully.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    pub assistant_message_id: MessageId,
    /// Name of the tool to invoke.
    pub name: Arc<str>,
    /// Input passed to the tool, as given in the tool-use request.
    pub input: serde_json::Value,
    /// Where this tool use currently is in its lifecycle.
    pub status: PendingToolUseStatus,
}
275
276#[derive(Debug, Clone)]
277pub enum PendingToolUseStatus {
278    Idle,
279    Running { _task: Shared<Task<()>> },
280    Error(#[allow(unused)] Arc<str>),
281}
282
283impl PendingToolUseStatus {
284    pub fn is_idle(&self) -> bool {
285        matches!(self, PendingToolUseStatus::Idle)
286    }
287
288    pub fn is_error(&self) -> bool {
289        matches!(self, PendingToolUseStatus::Error(_))
290    }
291}