assistant2: Restructure storage of tool uses and results (#21194)

Marshall Bowers created

This PR restructures the storage of the tool uses and results in
`assistant2` so that they don't live on the individual messages.

It also introduces a `LanguageModelToolUseId` newtype for better type
safety.

Release Notes:

- N/A

Change summary

Cargo.lock                                       |   1 
crates/assistant/src/assistant_panel.rs          |   2 
crates/assistant/src/context.rs                  |  21 +-
crates/assistant2/Cargo.toml                     |   1 
crates/assistant2/src/assistant_panel.rs         |   7 
crates/assistant2/src/thread.rs                  | 157 +++++++++++------
crates/language_model/src/language_model.rs      |  20 ++
crates/language_model/src/request.rs             |   2 
crates/language_models/src/provider/anthropic.rs |   2 
9 files changed, 136 insertions(+), 77 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -465,6 +465,7 @@ dependencies = [
  "language_model",
  "language_model_selector",
  "proto",
+ "serde",
  "serde_json",
  "settings",
  "smol",

crates/assistant/src/assistant_panel.rs 🔗

@@ -1925,7 +1925,7 @@ impl ContextEditor {
                                     Content::ToolUse {
                                         range: tool_use.source_range.clone(),
                                         tool_use: LanguageModelToolUse {
-                                            id: tool_use.id.to_string(),
+                                            id: tool_use.id.clone(),
                                             name: tool_use.name.clone(),
                                             input: tool_use.input.clone(),
                                         },

crates/assistant/src/context.rs 🔗

@@ -27,8 +27,8 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
 use language_model::{
     LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
     LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
-    StopReason,
+    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
+    LanguageModelToolUseId, MessageContent, Role, StopReason,
 };
 use language_models::{
     provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
@@ -385,7 +385,7 @@ pub enum ContextEvent {
     },
     UsePendingTools,
     ToolFinished {
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
         output_range: Range<language::Anchor>,
     },
     Operation(ContextOperation),
@@ -479,7 +479,7 @@ pub enum Content {
     },
     ToolResult {
         range: Range<language::Anchor>,
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
     },
 }
 
@@ -546,7 +546,7 @@ pub struct Context {
     pub(crate) slash_commands: Arc<SlashCommandWorkingSet>,
     pub(crate) tools: Arc<ToolWorkingSet>,
     slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
-    pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
+    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
     message_anchors: Vec<MessageAnchor>,
     contents: Vec<Content>,
     messages_metadata: HashMap<MessageId, MessageMetadata>,
@@ -1126,7 +1126,7 @@ impl Context {
         self.pending_tool_uses_by_id.values().collect()
     }
 
-    pub fn get_tool_use_by_id(&self, id: &Arc<str>) -> Option<&PendingToolUse> {
+    pub fn get_tool_use_by_id(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
         self.pending_tool_uses_by_id.get(id)
     }
 
@@ -2153,7 +2153,7 @@ impl Context {
 
     pub fn insert_tool_output(
         &mut self,
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
         output: Task<Result<String>>,
         cx: &mut ModelContext<Self>,
     ) {
@@ -2340,11 +2340,10 @@ impl Context {
                                         let source_range = buffer.anchor_after(start_ix)
                                             ..buffer.anchor_after(end_ix);
 
-                                        let tool_use_id: Arc<str> = tool_use.id.into();
                                         this.pending_tool_uses_by_id.insert(
-                                            tool_use_id.clone(),
+                                            tool_use.id.clone(),
                                             PendingToolUse {
-                                                id: tool_use_id,
+                                                id: tool_use.id,
                                                 name: tool_use.name,
                                                 input: tool_use.input,
                                                 status: PendingToolUseStatus::Idle,
@@ -3203,7 +3202,7 @@ pub enum PendingSlashCommandStatus {
 
 #[derive(Debug, Clone)]
 pub struct PendingToolUse {
-    pub id: Arc<str>,
+    pub id: LanguageModelToolUseId,
     pub name: String,
     pub input: serde_json::Value,
     pub status: PendingToolUseStatus,

crates/assistant2/Cargo.toml 🔗

@@ -25,6 +25,7 @@ language_model.workspace = true
 language_model_selector.workspace = true
 proto.workspace = true
 settings.workspace = true
+serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
 theme.workspace = true

crates/assistant2/src/assistant_panel.rs 🔗

@@ -102,7 +102,12 @@ impl AssistantPanel {
                         let task = tool.run(tool_use.input, self.workspace.clone(), cx);
 
                         self.thread.update(cx, |thread, cx| {
-                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
+                            thread.insert_tool_output(
+                                tool_use.assistant_message_id,
+                                tool_use.id.clone(),
+                                task,
+                                cx,
+                            );
                         });
                     }
                 }

crates/assistant2/src/thread.rs 🔗

@@ -8,8 +8,10 @@ use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, EventEmitter, ModelContext, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
+    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
+    StopReason,
 };
+use serde::{Deserialize, Serialize};
 use util::post_inc;
 
 #[derive(Debug, Clone, Copy)]
@@ -17,34 +19,46 @@ pub enum RequestKind {
     Chat,
 }
 
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
+pub struct MessageId(usize);
+
+impl MessageId {
+    fn post_inc(&mut self) -> Self {
+        Self(post_inc(&mut self.0))
+    }
+}
+
 /// A message in a [`Thread`].
 #[derive(Debug, Clone)]
 pub struct Message {
+    pub id: MessageId,
     pub role: Role,
     pub text: String,
-    pub tool_uses: Vec<LanguageModelToolUse>,
-    pub tool_results: Vec<LanguageModelToolResult>,
 }
 
 /// A thread of conversation with the LLM.
 pub struct Thread {
     messages: Vec<Message>,
+    next_message_id: MessageId,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     tools: Arc<ToolWorkingSet>,
-    pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
-    completed_tool_uses_by_id: HashMap<Arc<str>, String>,
+    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
+    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
+    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 }
 
 impl Thread {
     pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
         Self {
-            tools,
             messages: Vec::new(),
+            next_message_id: MessageId(0),
             completion_count: 0,
             pending_completions: Vec::new(),
+            tools,
+            tool_uses_by_message: HashMap::default(),
+            tool_results_by_message: HashMap::default(),
             pending_tool_uses_by_id: HashMap::default(),
-            completed_tool_uses_by_id: HashMap::default(),
         }
     }
 
@@ -61,22 +75,11 @@ impl Thread {
     }
 
     pub fn insert_user_message(&mut self, text: impl Into<String>) {
-        let mut message = Message {
+        self.messages.push(Message {
+            id: self.next_message_id.post_inc(),
             role: Role::User,
             text: text.into(),
-            tool_uses: Vec::new(),
-            tool_results: Vec::new(),
-        };
-
-        for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
-            message.tool_results.push(LanguageModelToolResult {
-                tool_use_id: tool_use_id.to_string(),
-                content: tool_output,
-                is_error: false,
-            });
-        }
-
-        self.messages.push(message);
+        });
     }
 
     pub fn to_completion_request(
@@ -98,10 +101,12 @@ impl Thread {
                 cache: false,
             };
 
-            for tool_result in &message.tool_results {
-                request_message
-                    .content
-                    .push(MessageContent::ToolResult(tool_result.clone()));
+            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
+                for tool_result in tool_results {
+                    request_message
+                        .content
+                        .push(MessageContent::ToolResult(tool_result.clone()));
+                }
             }
 
             if !message.text.is_empty() {
@@ -110,10 +115,12 @@ impl Thread {
                     .push(MessageContent::Text(message.text.clone()));
             }
 
-            for tool_use in &message.tool_uses {
-                request_message
-                    .content
-                    .push(MessageContent::ToolUse(tool_use.clone()));
+            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
+                for tool_use in tool_uses {
+                    request_message
+                        .content
+                        .push(MessageContent::ToolUse(tool_use.clone()));
+                }
             }
 
             request.messages.push(request_message);
@@ -143,10 +150,9 @@ impl Thread {
                         match event {
                             LanguageModelCompletionEvent::StartMessage { .. } => {
                                 thread.messages.push(Message {
+                                    id: thread.next_message_id.post_inc(),
                                     role: Role::Assistant,
                                     text: String::new(),
-                                    tool_uses: Vec::new(),
-                                    tool_results: Vec::new(),
                                 });
                             }
                             LanguageModelCompletionEvent::Stop(reason) => {
@@ -160,22 +166,28 @@ impl Thread {
                                 }
                             }
                             LanguageModelCompletionEvent::ToolUse(tool_use) => {
-                                if let Some(last_message) = thread.messages.last_mut() {
-                                    if last_message.role == Role::Assistant {
-                                        last_message.tool_uses.push(tool_use.clone());
-                                    }
+                                if let Some(last_assistant_message) = thread
+                                    .messages
+                                    .iter()
+                                    .rfind(|message| message.role == Role::Assistant)
+                                {
+                                    thread
+                                        .tool_uses_by_message
+                                        .entry(last_assistant_message.id)
+                                        .or_default()
+                                        .push(tool_use.clone());
+
+                                    thread.pending_tool_uses_by_id.insert(
+                                        tool_use.id.clone(),
+                                        PendingToolUse {
+                                            assistant_message_id: last_assistant_message.id,
+                                            id: tool_use.id,
+                                            name: tool_use.name,
+                                            input: tool_use.input,
+                                            status: PendingToolUseStatus::Idle,
+                                        },
+                                    );
                                 }
-
-                                let tool_use_id: Arc<str> = tool_use.id.into();
-                                thread.pending_tool_uses_by_id.insert(
-                                    tool_use_id.clone(),
-                                    PendingToolUse {
-                                        id: tool_use_id,
-                                        name: tool_use.name,
-                                        input: tool_use.input,
-                                        status: PendingToolUseStatus::Idle,
-                                    },
-                                );
                             }
                         }
 
@@ -235,7 +247,8 @@ impl Thread {
 
     pub fn insert_tool_output(
         &mut self,
-        tool_use_id: Arc<str>,
+        assistant_message_id: MessageId,
+        tool_use_id: LanguageModelToolUseId,
         output: Task<Result<String>>,
         cx: &mut ModelContext<Self>,
     ) {
@@ -244,19 +257,39 @@ impl Thread {
             async move {
                 let output = output.await;
                 thread
-                    .update(&mut cx, |thread, cx| match output {
-                        Ok(output) => {
-                            thread
-                                .completed_tool_uses_by_id
-                                .insert(tool_use_id.clone(), output);
+                    .update(&mut cx, |thread, cx| {
+                        // The tool use was requested by an Assistant message,
+                        // so we want to attach the tool results to the next
+                        // user message.
+                        let next_user_message = MessageId(assistant_message_id.0 + 1);
+
+                        let tool_results = thread
+                            .tool_results_by_message
+                            .entry(next_user_message)
+                            .or_default();
+
+                        match output {
+                            Ok(output) => {
+                                tool_results.push(LanguageModelToolResult {
+                                    tool_use_id: tool_use_id.to_string(),
+                                    content: output,
+                                    is_error: false,
+                                });
 
-                            cx.emit(ThreadEvent::ToolFinished { tool_use_id });
-                        }
-                        Err(err) => {
-                            if let Some(tool_use) =
-                                thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
-                            {
-                                tool_use.status = PendingToolUseStatus::Error(err.to_string());
+                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+                            }
+                            Err(err) => {
+                                tool_results.push(LanguageModelToolResult {
+                                    tool_use_id: tool_use_id.to_string(),
+                                    content: err.to_string(),
+                                    is_error: true,
+                                });
+
+                                if let Some(tool_use) =
+                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
+                                {
+                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
+                                }
                             }
                         }
                     })
@@ -278,7 +311,7 @@ pub enum ThreadEvent {
     UsePendingTools,
     ToolFinished {
         #[allow(unused)]
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
     },
 }
 
@@ -291,7 +324,9 @@ struct PendingCompletion {
 
 #[derive(Debug, Clone)]
 pub struct PendingToolUse {
-    pub id: Arc<str>,
+    pub id: LanguageModelToolUseId,
+    /// The ID of the Assistant message in which the tool use was requested.
+    pub assistant_message_id: MessageId,
     pub name: String,
     pub input: serde_json::Value,
     pub status: PendingToolUseStatus,

crates/language_model/src/language_model.rs 🔗

@@ -63,9 +63,27 @@ pub enum StopReason {
     ToolUse,
 }
 
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUseId(Arc<str>);
+
+impl fmt::Display for LanguageModelToolUseId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl<T> From<T> for LanguageModelToolUseId
+where
+    T: Into<Arc<str>>,
+{
+    fn from(value: T) -> Self {
+        Self(value.into())
+    }
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUse {
-    pub id: String,
+    pub id: LanguageModelToolUseId,
     pub name: String,
     pub input: serde_json::Value,
 }

crates/language_model/src/request.rs 🔗

@@ -347,7 +347,7 @@ impl LanguageModelRequest {
                             }
                             MessageContent::ToolUse(tool_use) => {
                                 Some(anthropic::RequestContent::ToolUse {
-                                    id: tool_use.id,
+                                    id: tool_use.id.to_string(),
                                     name: tool_use.name,
                                     input: tool_use.input,
                                     cache_control,

crates/language_models/src/provider/anthropic.rs 🔗

@@ -498,7 +498,7 @@ pub fn map_to_language_model_completion_events(
                                     Some(maybe!({
                                         Ok(LanguageModelCompletionEvent::ToolUse(
                                             LanguageModelToolUse {
-                                                id: tool_use.id,
+                                                id: tool_use.id.into(),
                                                 name: tool_use.name,
                                                 input: if tool_use.input_json.is_empty() {
                                                     serde_json::Value::Null