From 968ffaa3fd801b3a436551705db636b0c89609b6 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 25 Nov 2024 21:53:27 -0500 Subject: [PATCH] assistant2: Restructure storage of tool uses and results (#21194) This PR restructures the storage of the tool uses and results in `assistant2` so that they don't live on the individual messages. It also introduces a `LanguageModelToolUseId` newtype for better type safety. Release Notes: - N/A --- Cargo.lock | 1 + crates/assistant/src/assistant_panel.rs | 2 +- crates/assistant/src/context.rs | 21 ++- crates/assistant2/Cargo.toml | 1 + crates/assistant2/src/assistant_panel.rs | 7 +- crates/assistant2/src/thread.rs | 157 +++++++++++------- crates/language_model/src/language_model.rs | 20 ++- crates/language_model/src/request.rs | 2 +- .../language_models/src/provider/anthropic.rs | 2 +- 9 files changed, 136 insertions(+), 77 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a18caa3d11a4180a249de65709a510f41e8a6d6..166adb6588e3b0cda7bdb4ca8a8f284c9833d5db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -465,6 +465,7 @@ dependencies = [ "language_model", "language_model_selector", "proto", + "serde", "serde_json", "settings", "smol", diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index e1ce7c4ab293fb223064256d4282b39297881b44..7467d5dfd482d2e0fc8c9739d271c6bc69238273 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1925,7 +1925,7 @@ impl ContextEditor { Content::ToolUse { range: tool_use.source_range.clone(), tool_use: LanguageModelToolUse { - id: tool_use.id.to_string(), + id: tool_use.id.clone(), name: tool_use.name.clone(), input: tool_use.input.clone(), }, diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index ac032accc3916ec98dd7c415db658b6ad01df2ff..032a66b4c762b2d92f3d76eb6c7d8f25d370479f 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -27,8 +27,8 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P use language_model::{ LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, - StopReason, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, + LanguageModelToolUseId, MessageContent, Role, StopReason, }; use language_models::{ provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError}, @@ -385,7 +385,7 @@ pub enum ContextEvent { }, UsePendingTools, ToolFinished { - tool_use_id: Arc, + tool_use_id: LanguageModelToolUseId, output_range: Range, }, Operation(ContextOperation), @@ -479,7 +479,7 @@ pub enum Content { }, ToolResult { range: Range, - tool_use_id: Arc, + tool_use_id: LanguageModelToolUseId, }, } @@ -546,7 +546,7 @@ pub struct Context { pub(crate) slash_commands: Arc, pub(crate) tools: Arc, slash_command_output_sections: Vec>, - pending_tool_uses_by_id: HashMap, PendingToolUse>, + pending_tool_uses_by_id: HashMap, message_anchors: Vec, contents: Vec, messages_metadata: HashMap, @@ -1126,7 +1126,7 @@ impl Context { self.pending_tool_uses_by_id.values().collect() } - pub fn get_tool_use_by_id(&self, id: &Arc) -> Option<&PendingToolUse> { + pub fn get_tool_use_by_id(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> { self.pending_tool_uses_by_id.get(id) } @@ -2153,7 +2153,7 @@ impl Context { pub fn insert_tool_output( &mut self, - tool_use_id: Arc, + tool_use_id: LanguageModelToolUseId, output: Task>, cx: &mut ModelContext, ) { @@ -2340,11 +2340,10 @@ impl Context { let source_range = buffer.anchor_after(start_ix) ..buffer.anchor_after(end_ix); - let tool_use_id: Arc = tool_use.id.into(); this.pending_tool_uses_by_id.insert( - tool_use_id.clone(), + tool_use.id.clone(), PendingToolUse { - id: tool_use_id, + id: tool_use.id, name: tool_use.name, input: tool_use.input, status: PendingToolUseStatus::Idle, @@ -3203,7 +3202,7 @@ pub enum PendingSlashCommandStatus { #[derive(Debug, Clone)] pub struct PendingToolUse { - pub id: Arc, + pub id: LanguageModelToolUseId, pub name: String, pub input: serde_json::Value, pub status: PendingToolUseStatus, diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 60c168079d0bf5e5a26724aa10217a21c29257dc..ca563b05c8d469583c5a51ab09f3473a0678ee43 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -25,6 +25,7 @@ language_model.workspace = true language_model_selector.workspace = true proto.workspace = true settings.workspace = true +serde.workspace = true serde_json.workspace = true smol.workspace = true theme.workspace = true diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 4ebf07e9d43bda37190cfc5acde48e7100848d8c..bf457d6c71826efe8640410504ef63d19f7c731c 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -102,7 +102,12 @@ impl AssistantPanel { let task = tool.run(tool_use.input, self.workspace.clone(), cx); self.thread.update(cx, |thread, cx| { - thread.insert_tool_output(tool_use.id.clone(), task, cx); + thread.insert_tool_output( + tool_use.assistant_message_id, + tool_use.id.clone(), + task, + cx, + ); }); } } diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index d8263d15f7c45f860538ebda12ef84773d3657a9..0d2aab6905f62dbe5d9f5643843a64703ebfec6f 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -8,8 +8,10 @@ use futures::{FutureExt as _, StreamExt as _}; use gpui::{AppContext, EventEmitter, ModelContext, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason, + LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, + StopReason, }; +use serde::{Deserialize, Serialize}; use util::post_inc; #[derive(Debug, Clone, Copy)] @@ -17,34 +19,46 @@ pub enum RequestKind { Chat, } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +pub struct MessageId(usize); + +impl MessageId { + fn post_inc(&mut self) -> Self { + Self(post_inc(&mut self.0)) + } +} + /// A message in a [`Thread`]. #[derive(Debug, Clone)] pub struct Message { + pub id: MessageId, pub role: Role, pub text: String, - pub tool_uses: Vec, - pub tool_results: Vec, } /// A thread of conversation with the LLM. pub struct Thread { messages: Vec, + next_message_id: MessageId, completion_count: usize, pending_completions: Vec, tools: Arc, - pending_tool_uses_by_id: HashMap, PendingToolUse>, - completed_tool_uses_by_id: HashMap, String>, + tool_uses_by_message: HashMap>, + tool_results_by_message: HashMap>, + pending_tool_uses_by_id: HashMap, } impl Thread { pub fn new(tools: Arc, _cx: &mut ModelContext) -> Self { Self { - tools, messages: Vec::new(), + next_message_id: MessageId(0), completion_count: 0, pending_completions: Vec::new(), + tools, + tool_uses_by_message: HashMap::default(), + tool_results_by_message: HashMap::default(), pending_tool_uses_by_id: HashMap::default(), - completed_tool_uses_by_id: HashMap::default(), } } @@ -61,22 +75,11 @@ impl Thread { } pub fn insert_user_message(&mut self, text: impl Into) { - let mut message = Message { + self.messages.push(Message { + id: self.next_message_id.post_inc(), role: Role::User, text: text.into(), - tool_uses: Vec::new(), - tool_results: Vec::new(), - }; - - for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() { - message.tool_results.push(LanguageModelToolResult { - tool_use_id: tool_use_id.to_string(), - content: tool_output, - is_error: false, - }); - } - - self.messages.push(message); + }); } pub fn to_completion_request( @@ -98,10 +101,12 @@ impl Thread { cache: false, }; - for tool_result in &message.tool_results { - request_message - .content - .push(MessageContent::ToolResult(tool_result.clone())); + if let Some(tool_results) = self.tool_results_by_message.get(&message.id) { + for tool_result in tool_results { + request_message + .content + .push(MessageContent::ToolResult(tool_result.clone())); + } } if !message.text.is_empty() { @@ -110,10 +115,12 @@ impl Thread { .push(MessageContent::Text(message.text.clone())); } - for tool_use in &message.tool_uses { - request_message - .content - .push(MessageContent::ToolUse(tool_use.clone())); + if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) { + for tool_use in tool_uses { + request_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); + } } request.messages.push(request_message); @@ -143,10 +150,9 @@ impl Thread { match event { LanguageModelCompletionEvent::StartMessage { .. } => { thread.messages.push(Message { + id: thread.next_message_id.post_inc(), role: Role::Assistant, text: String::new(), - tool_uses: Vec::new(), - tool_results: Vec::new(), }); } LanguageModelCompletionEvent::Stop(reason) => { @@ -160,22 +166,28 @@ impl Thread { } } LanguageModelCompletionEvent::ToolUse(tool_use) => { - if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant { - last_message.tool_uses.push(tool_use.clone()); - } + if let Some(last_assistant_message) = thread + .messages + .iter() + .rfind(|message| message.role == Role::Assistant) + { + thread + .tool_uses_by_message + .entry(last_assistant_message.id) + .or_default() + .push(tool_use.clone()); + + thread.pending_tool_uses_by_id.insert( + tool_use.id.clone(), + PendingToolUse { + assistant_message_id: last_assistant_message.id, + id: tool_use.id, + name: tool_use.name, + input: tool_use.input, + status: PendingToolUseStatus::Idle, + }, + ); } - - let tool_use_id: Arc = tool_use.id.into(); - thread.pending_tool_uses_by_id.insert( - tool_use_id.clone(), - PendingToolUse { - id: tool_use_id, - name: tool_use.name, - input: tool_use.input, - status: PendingToolUseStatus::Idle, - }, - ); } } @@ -235,7 +247,8 @@ impl Thread { pub fn insert_tool_output( &mut self, - tool_use_id: Arc, + assistant_message_id: MessageId, + tool_use_id: LanguageModelToolUseId, output: Task>, cx: &mut ModelContext, ) { @@ -244,19 +257,39 @@ impl Thread { async move { let output = output.await; thread - .update(&mut cx, |thread, cx| match output { - Ok(output) => { - thread - .completed_tool_uses_by_id - .insert(tool_use_id.clone(), output); + .update(&mut cx, |thread, cx| { + // The tool use was requested by an Assistant message, + // so we want to attach the tool results to the next + // user message. + let next_user_message = MessageId(assistant_message_id.0 + 1); + + let tool_results = thread + .tool_results_by_message + .entry(next_user_message) + .or_default(); + + match output { + Ok(output) => { + tool_results.push(LanguageModelToolResult { + tool_use_id: tool_use_id.to_string(), + content: output, + is_error: false, + }); - cx.emit(ThreadEvent::ToolFinished { tool_use_id }); - } - Err(err) => { - if let Some(tool_use) = - thread.pending_tool_uses_by_id.get_mut(&tool_use_id) - { - tool_use.status = PendingToolUseStatus::Error(err.to_string()); + cx.emit(ThreadEvent::ToolFinished { tool_use_id }); + } + Err(err) => { + tool_results.push(LanguageModelToolResult { + tool_use_id: tool_use_id.to_string(), + content: err.to_string(), + is_error: true, + }); + + if let Some(tool_use) = + thread.pending_tool_uses_by_id.get_mut(&tool_use_id) + { + tool_use.status = PendingToolUseStatus::Error(err.to_string()); + } } } }) @@ -278,7 +311,7 @@ pub enum ThreadEvent { UsePendingTools, ToolFinished { #[allow(unused)] - tool_use_id: Arc, + tool_use_id: LanguageModelToolUseId, }, } @@ -291,7 +324,9 @@ struct PendingCompletion { #[derive(Debug, Clone)] pub struct PendingToolUse { - pub id: Arc, + pub id: LanguageModelToolUseId, + /// The ID of the Assistant message in which the tool use was requested. + pub assistant_message_id: MessageId, pub name: String, pub input: serde_json::Value, pub status: PendingToolUseStatus, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index f9df34a2d1707c63ab5a59b49a0965706f3dcf5b..3c5a00bd85e682fa7f53b747e786d25ded013249 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -63,9 +63,27 @@ pub enum StopReason { ToolUse, } +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelToolUseId(Arc); + +impl fmt::Display for LanguageModelToolUseId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for LanguageModelToolUseId +where + T: Into>, +{ + fn from(value: T) -> Self { + Self(value.into()) + } +} + #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] pub struct LanguageModelToolUse { - pub id: String, + pub id: LanguageModelToolUseId, pub name: String, pub input: serde_json::Value, } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 06dde1862ab37ed2a4fbec4a8e67cb1bd18254cf..e6f7f210c73f7253df1bf3a3b7f9c7a425258f5b 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -347,7 +347,7 @@ impl LanguageModelRequest { } MessageContent::ToolUse(tool_use) => { Some(anthropic::RequestContent::ToolUse { - id: tool_use.id, + id: tool_use.id.to_string(), name: tool_use.name, input: tool_use.input, cache_control, diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 87460b824eb24062bddc48890da55b78a29449c9..e882bb900de06bfa0feb5d711e8de2cb4b011596 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -498,7 +498,7 @@ pub fn map_to_language_model_completion_events( Some(maybe!({ Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { - id: tool_use.id, + id: tool_use.id.into(), name: tool_use.name, input: if tool_use.input_json.is_empty() { serde_json::Value::Null