diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 238f1014017c937f4db22df87c9403e54ef25a55..ee85f3cc94430867ce4405467d72540e21a78bd9 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -8,8 +8,9 @@ use futures::StreamExt as _; use gpui::{App, Context, EventEmitter, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId, - MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, + Role, StopReason, }; use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; @@ -88,7 +89,7 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), tools, - tool_use: ToolUseState::default(), + tool_use: ToolUseState::new(), } } @@ -99,6 +100,7 @@ impl Thread { _cx: &mut Context, ) -> Self { let next_message_id = MessageId(saved.messages.len()); + let tool_use = ToolUseState::from_saved_messages(&saved.messages); Self { id, @@ -120,7 +122,7 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), tools, - tool_use: ToolUseState::default(), + tool_use, } } @@ -189,6 +191,10 @@ impl Thread { self.tool_use.tool_uses_for_message(id) } + pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> { + self.tool_use.tool_results_for_message(id) + } + pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { self.tool_use.message_has_tool_results(message_id) } diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index b9ef30cb56ec00038e1292c2cd1ac4f7a5366e31..f6143186828ccd47160c971611f06540075e1346 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -14,7 +14,7 @@ use gpui::{ }; use heed::types::{SerdeBincode, SerdeJson}; use heed::Database; -use language_model::Role; +use language_model::{LanguageModelToolUseId, Role}; use project::Project; use serde::{Deserialize, Serialize}; use util::ResultExt as _; @@ -113,6 +113,24 @@ impl ThreadStore { id: message.id, role: message.role, text: message.text.clone(), + tool_uses: thread + .tool_uses_for_message(message.id) + .into_iter() + .map(|tool_use| SavedToolUse { + id: tool_use.id, + name: tool_use.name, + input: tool_use.input, + }) + .collect(), + tool_results: thread + .tool_results_for_message(message.id) + .into_iter() + .map(|tool_result| SavedToolResult { + tool_use_id: tool_result.tool_use_id.clone(), + is_error: tool_result.is_error, + content: tool_result.content.clone(), + }) + .collect(), }) .collect(), }; @@ -239,11 +257,29 @@ pub struct SavedThread { pub messages: Vec, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct SavedMessage { pub id: MessageId, pub role: Role, pub text: String, + #[serde(default)] + pub tool_uses: Vec, + #[serde(default)] + pub tool_results: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SavedToolUse { + pub id: LanguageModelToolUseId, + pub name: SharedString, + pub input: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SavedToolResult { + pub tool_use_id: LanguageModelToolUseId, + pub is_error: bool, + pub content: Arc, } struct GlobalThreadsDatabase( diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index 12b73554f9db98925028be0a9b6016eb90167884..8340febac18c6f10e18b2094fae949e0c5ffb45f 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -7,10 +7,11 @@ use futures::FutureExt as _; use gpui::{SharedString, Task}; use language_model::{ LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, - LanguageModelToolUseId, MessageContent, + LanguageModelToolUseId, MessageContent, Role, }; use crate::thread::MessageId; +use crate::thread_store::SavedMessage; #[derive(Debug)] pub struct ToolUse { @@ -28,7 +29,6 @@ pub enum ToolUseStatus { Error(SharedString), } -#[derive(Default)] pub struct ToolUseState { tool_uses_by_assistant_message: HashMap>, tool_uses_by_user_message: HashMap>, @@ -37,6 +37,65 @@ pub struct ToolUseState { } impl ToolUseState { + pub fn new() -> Self { + Self { + tool_uses_by_assistant_message: HashMap::default(), + tool_uses_by_user_message: HashMap::default(), + tool_results: HashMap::default(), + pending_tool_uses_by_id: HashMap::default(), + } + } + + pub fn from_saved_messages(messages: &[SavedMessage]) -> Self { + let mut this = Self::new(); + + for message in messages { + match message.role { + Role::Assistant => { + if !message.tool_uses.is_empty() { + this.tool_uses_by_assistant_message.insert( + message.id, + message + .tool_uses + .iter() + .map(|tool_use| LanguageModelToolUse { + id: tool_use.id.clone(), + name: tool_use.name.clone().into(), + input: tool_use.input.clone(), + }) + .collect(), + ); + } + } + Role::User => { + if !message.tool_results.is_empty() { + let tool_uses_by_user_message = this + .tool_uses_by_user_message + .entry(message.id) + .or_default(); + + for tool_result in &message.tool_results { + let tool_use_id = tool_result.tool_use_id.clone(); + + tool_uses_by_user_message.push(tool_use_id.clone()); + this.tool_results.insert( + tool_use_id.clone(), + LanguageModelToolResult { + tool_use_id, + is_error: tool_result.is_error, + content: tool_result.content.clone(), + }, + ); + } + } + } + Role::System => {} + } + } + + this + } + pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { self.pending_tool_uses_by_id.values().collect() } @@ -84,6 +143,17 @@ impl ToolUseState { tool_uses } + pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> { + let empty = Vec::new(); + + self.tool_uses_by_user_message + .get(&message_id) + .unwrap_or(&empty) + .iter() + .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id)) + .collect() + } + pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { self.tool_uses_by_user_message .get(&message_id)