Detailed changes
@@ -8,8 +8,9 @@ use futures::StreamExt as _;
use gpui::{App, Context, EventEmitter, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
- MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+ LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
+ Role, StopReason,
};
use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _};
@@ -88,7 +89,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
- tool_use: ToolUseState::default(),
+ tool_use: ToolUseState::new(),
}
}
@@ -99,6 +100,7 @@ impl Thread {
_cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(saved.messages.len());
+ let tool_use = ToolUseState::from_saved_messages(&saved.messages);
Self {
id,
@@ -120,7 +122,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
- tool_use: ToolUseState::default(),
+ tool_use,
}
}
@@ -189,6 +191,10 @@ impl Thread {
self.tool_use.tool_uses_for_message(id)
}
+ pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
+ self.tool_use.tool_results_for_message(id)
+ }
+
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
@@ -14,7 +14,7 @@ use gpui::{
};
use heed::types::{SerdeBincode, SerdeJson};
use heed::Database;
-use language_model::Role;
+use language_model::{LanguageModelToolUseId, Role};
use project::Project;
use serde::{Deserialize, Serialize};
use util::ResultExt as _;
@@ -113,6 +113,24 @@ impl ThreadStore {
id: message.id,
role: message.role,
text: message.text.clone(),
+ tool_uses: thread
+ .tool_uses_for_message(message.id)
+ .into_iter()
+ .map(|tool_use| SavedToolUse {
+ id: tool_use.id,
+ name: tool_use.name,
+ input: tool_use.input,
+ })
+ .collect(),
+ tool_results: thread
+ .tool_results_for_message(message.id)
+ .into_iter()
+ .map(|tool_result| SavedToolResult {
+ tool_use_id: tool_result.tool_use_id.clone(),
+ is_error: tool_result.is_error,
+ content: tool_result.content.clone(),
+ })
+ .collect(),
})
.collect(),
};
@@ -239,11 +257,29 @@ pub struct SavedThread {
pub messages: Vec<SavedMessage>,
}
-#[derive(Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage {
pub id: MessageId,
pub role: Role,
pub text: String,
+ #[serde(default)]
+ pub tool_uses: Vec<SavedToolUse>,
+ #[serde(default)]
+ pub tool_results: Vec<SavedToolResult>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SavedToolUse {
+ pub id: LanguageModelToolUseId,
+ pub name: SharedString,
+ pub input: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SavedToolResult {
+ pub tool_use_id: LanguageModelToolUseId,
+ pub is_error: bool,
+ pub content: Arc<str>,
}
struct GlobalThreadsDatabase(
@@ -7,10 +7,11 @@ use futures::FutureExt as _;
use gpui::{SharedString, Task};
use language_model::{
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
- LanguageModelToolUseId, MessageContent,
+ LanguageModelToolUseId, MessageContent, Role,
};
use crate::thread::MessageId;
+use crate::thread_store::SavedMessage;
#[derive(Debug)]
pub struct ToolUse {
@@ -28,7 +29,6 @@ pub enum ToolUseStatus {
Error(SharedString),
}
-#[derive(Default)]
pub struct ToolUseState {
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
@@ -37,6 +37,65 @@ pub struct ToolUseState {
}
impl ToolUseState {
+ pub fn new() -> Self {
+ Self {
+ tool_uses_by_assistant_message: HashMap::default(),
+ tool_uses_by_user_message: HashMap::default(),
+ tool_results: HashMap::default(),
+ pending_tool_uses_by_id: HashMap::default(),
+ }
+ }
+
+ pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
+ let mut this = Self::new();
+
+ for message in messages {
+ match message.role {
+ Role::Assistant => {
+ if !message.tool_uses.is_empty() {
+ this.tool_uses_by_assistant_message.insert(
+ message.id,
+ message
+ .tool_uses
+ .iter()
+ .map(|tool_use| LanguageModelToolUse {
+ id: tool_use.id.clone(),
+ name: tool_use.name.clone().into(),
+ input: tool_use.input.clone(),
+ })
+ .collect(),
+ );
+ }
+ }
+ Role::User => {
+ if !message.tool_results.is_empty() {
+ let tool_uses_by_user_message = this
+ .tool_uses_by_user_message
+ .entry(message.id)
+ .or_default();
+
+ for tool_result in &message.tool_results {
+ let tool_use_id = tool_result.tool_use_id.clone();
+
+ tool_uses_by_user_message.push(tool_use_id.clone());
+ this.tool_results.insert(
+ tool_use_id.clone(),
+ LanguageModelToolResult {
+ tool_use_id,
+ is_error: tool_result.is_error,
+ content: tool_result.content.clone(),
+ },
+ );
+ }
+ }
+ }
+ Role::System => {}
+ }
+ }
+
+ this
+ }
+
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
@@ -84,6 +143,17 @@ impl ToolUseState {
tool_uses
}
+ pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
+ let empty = Vec::new();
+
+ self.tool_uses_by_user_message
+ .get(&message_id)
+ .unwrap_or(&empty)
+ .iter()
+ .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
+ .collect()
+ }
+
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_uses_by_user_message
.get(&message_id)