Cargo.lock 🔗
@@ -465,6 +465,7 @@ dependencies = [
"language_model",
"language_model_selector",
"proto",
+ "serde",
"serde_json",
"settings",
"smol",
Marshall Bowers created
This PR restructures the storage of the tool uses and results in
`assistant2` so that they don't live on the individual messages.
It also introduces a `LanguageModelToolUseId` newtype for better type
safety.
Release Notes:
- N/A
Cargo.lock | 1
crates/assistant/src/assistant_panel.rs | 2
crates/assistant/src/context.rs | 21 +-
crates/assistant2/Cargo.toml | 1
crates/assistant2/src/assistant_panel.rs | 7
crates/assistant2/src/thread.rs | 157 +++++++++++------
crates/language_model/src/language_model.rs | 20 ++
crates/language_model/src/request.rs | 2
crates/language_models/src/provider/anthropic.rs | 2
9 files changed, 136 insertions(+), 77 deletions(-)
@@ -465,6 +465,7 @@ dependencies = [
"language_model",
"language_model_selector",
"proto",
+ "serde",
"serde_json",
"settings",
"smol",
@@ -1925,7 +1925,7 @@ impl ContextEditor {
Content::ToolUse {
range: tool_use.source_range.clone(),
tool_use: LanguageModelToolUse {
- id: tool_use.id.to_string(),
+ id: tool_use.id.clone(),
name: tool_use.name.clone(),
input: tool_use.input.clone(),
},
@@ -27,8 +27,8 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
- StopReason,
+ LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
+ LanguageModelToolUseId, MessageContent, Role, StopReason,
};
use language_models::{
provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
@@ -385,7 +385,7 @@ pub enum ContextEvent {
},
UsePendingTools,
ToolFinished {
- tool_use_id: Arc<str>,
+ tool_use_id: LanguageModelToolUseId,
output_range: Range<language::Anchor>,
},
Operation(ContextOperation),
@@ -479,7 +479,7 @@ pub enum Content {
},
ToolResult {
range: Range<language::Anchor>,
- tool_use_id: Arc<str>,
+ tool_use_id: LanguageModelToolUseId,
},
}
@@ -546,7 +546,7 @@ pub struct Context {
pub(crate) slash_commands: Arc<SlashCommandWorkingSet>,
pub(crate) tools: Arc<ToolWorkingSet>,
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
- pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
+ pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
message_anchors: Vec<MessageAnchor>,
contents: Vec<Content>,
messages_metadata: HashMap<MessageId, MessageMetadata>,
@@ -1126,7 +1126,7 @@ impl Context {
self.pending_tool_uses_by_id.values().collect()
}
- pub fn get_tool_use_by_id(&self, id: &Arc<str>) -> Option<&PendingToolUse> {
+ pub fn get_tool_use_by_id(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
self.pending_tool_uses_by_id.get(id)
}
@@ -2153,7 +2153,7 @@ impl Context {
pub fn insert_tool_output(
&mut self,
- tool_use_id: Arc<str>,
+ tool_use_id: LanguageModelToolUseId,
output: Task<Result<String>>,
cx: &mut ModelContext<Self>,
) {
@@ -2340,11 +2340,10 @@ impl Context {
let source_range = buffer.anchor_after(start_ix)
..buffer.anchor_after(end_ix);
- let tool_use_id: Arc<str> = tool_use.id.into();
this.pending_tool_uses_by_id.insert(
- tool_use_id.clone(),
+ tool_use.id.clone(),
PendingToolUse {
- id: tool_use_id,
+ id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
status: PendingToolUseStatus::Idle,
@@ -3203,7 +3202,7 @@ pub enum PendingSlashCommandStatus {
#[derive(Debug, Clone)]
pub struct PendingToolUse {
- pub id: Arc<str>,
+ pub id: LanguageModelToolUseId,
pub name: String,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
@@ -25,6 +25,7 @@ language_model.workspace = true
language_model_selector.workspace = true
proto.workspace = true
settings.workspace = true
+serde.workspace = true
serde_json.workspace = true
smol.workspace = true
theme.workspace = true
@@ -102,7 +102,12 @@ impl AssistantPanel {
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
self.thread.update(cx, |thread, cx| {
- thread.insert_tool_output(tool_use.id.clone(), task, cx);
+ thread.insert_tool_output(
+ tool_use.assistant_message_id,
+ tool_use.id.clone(),
+ task,
+ cx,
+ );
});
}
}
@@ -8,8 +8,10 @@ use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
+ LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
+ StopReason,
};
+use serde::{Deserialize, Serialize};
use util::post_inc;
#[derive(Debug, Clone, Copy)]
@@ -17,34 +19,46 @@ pub enum RequestKind {
Chat,
}
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
+pub struct MessageId(usize);
+
+impl MessageId {
+ fn post_inc(&mut self) -> Self {
+ Self(post_inc(&mut self.0))
+ }
+}
+
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
+ pub id: MessageId,
pub role: Role,
pub text: String,
- pub tool_uses: Vec<LanguageModelToolUse>,
- pub tool_results: Vec<LanguageModelToolResult>,
}
/// A thread of conversation with the LLM.
pub struct Thread {
messages: Vec<Message>,
+ next_message_id: MessageId,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
tools: Arc<ToolWorkingSet>,
- pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
- completed_tool_uses_by_id: HashMap<Arc<str>, String>,
+ tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
+ tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
+ pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
impl Thread {
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
Self {
- tools,
messages: Vec::new(),
+ next_message_id: MessageId(0),
completion_count: 0,
pending_completions: Vec::new(),
+ tools,
+ tool_uses_by_message: HashMap::default(),
+ tool_results_by_message: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
- completed_tool_uses_by_id: HashMap::default(),
}
}
@@ -61,22 +75,11 @@ impl Thread {
}
pub fn insert_user_message(&mut self, text: impl Into<String>) {
- let mut message = Message {
+ self.messages.push(Message {
+ id: self.next_message_id.post_inc(),
role: Role::User,
text: text.into(),
- tool_uses: Vec::new(),
- tool_results: Vec::new(),
- };
-
- for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
- message.tool_results.push(LanguageModelToolResult {
- tool_use_id: tool_use_id.to_string(),
- content: tool_output,
- is_error: false,
- });
- }
-
- self.messages.push(message);
+ });
}
pub fn to_completion_request(
@@ -98,10 +101,12 @@ impl Thread {
cache: false,
};
- for tool_result in &message.tool_results {
- request_message
- .content
- .push(MessageContent::ToolResult(tool_result.clone()));
+ if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
+ for tool_result in tool_results {
+ request_message
+ .content
+ .push(MessageContent::ToolResult(tool_result.clone()));
+ }
}
if !message.text.is_empty() {
@@ -110,10 +115,12 @@ impl Thread {
.push(MessageContent::Text(message.text.clone()));
}
- for tool_use in &message.tool_uses {
- request_message
- .content
- .push(MessageContent::ToolUse(tool_use.clone()));
+ if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
+ for tool_use in tool_uses {
+ request_message
+ .content
+ .push(MessageContent::ToolUse(tool_use.clone()));
+ }
}
request.messages.push(request_message);
@@ -143,10 +150,9 @@ impl Thread {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
thread.messages.push(Message {
+ id: thread.next_message_id.post_inc(),
role: Role::Assistant,
text: String::new(),
- tool_uses: Vec::new(),
- tool_results: Vec::new(),
});
}
LanguageModelCompletionEvent::Stop(reason) => {
@@ -160,22 +166,28 @@ impl Thread {
}
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
- if let Some(last_message) = thread.messages.last_mut() {
- if last_message.role == Role::Assistant {
- last_message.tool_uses.push(tool_use.clone());
- }
+ if let Some(last_assistant_message) = thread
+ .messages
+ .iter()
+ .rfind(|message| message.role == Role::Assistant)
+ {
+ thread
+ .tool_uses_by_message
+ .entry(last_assistant_message.id)
+ .or_default()
+ .push(tool_use.clone());
+
+ thread.pending_tool_uses_by_id.insert(
+ tool_use.id.clone(),
+ PendingToolUse {
+ assistant_message_id: last_assistant_message.id,
+ id: tool_use.id,
+ name: tool_use.name,
+ input: tool_use.input,
+ status: PendingToolUseStatus::Idle,
+ },
+ );
}
-
- let tool_use_id: Arc<str> = tool_use.id.into();
- thread.pending_tool_uses_by_id.insert(
- tool_use_id.clone(),
- PendingToolUse {
- id: tool_use_id,
- name: tool_use.name,
- input: tool_use.input,
- status: PendingToolUseStatus::Idle,
- },
- );
}
}
@@ -235,7 +247,8 @@ impl Thread {
pub fn insert_tool_output(
&mut self,
- tool_use_id: Arc<str>,
+ assistant_message_id: MessageId,
+ tool_use_id: LanguageModelToolUseId,
output: Task<Result<String>>,
cx: &mut ModelContext<Self>,
) {
@@ -244,19 +257,39 @@ impl Thread {
async move {
let output = output.await;
thread
- .update(&mut cx, |thread, cx| match output {
- Ok(output) => {
- thread
- .completed_tool_uses_by_id
- .insert(tool_use_id.clone(), output);
+ .update(&mut cx, |thread, cx| {
+ // The tool use was requested by an Assistant message,
+ // so we want to attach the tool results to the next
+ // user message.
+ let next_user_message = MessageId(assistant_message_id.0 + 1);
+
+ let tool_results = thread
+ .tool_results_by_message
+ .entry(next_user_message)
+ .or_default();
+
+ match output {
+ Ok(output) => {
+ tool_results.push(LanguageModelToolResult {
+ tool_use_id: tool_use_id.to_string(),
+ content: output,
+ is_error: false,
+ });
- cx.emit(ThreadEvent::ToolFinished { tool_use_id });
- }
- Err(err) => {
- if let Some(tool_use) =
- thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
- {
- tool_use.status = PendingToolUseStatus::Error(err.to_string());
+ cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+ }
+ Err(err) => {
+ tool_results.push(LanguageModelToolResult {
+ tool_use_id: tool_use_id.to_string(),
+ content: err.to_string(),
+ is_error: true,
+ });
+
+ if let Some(tool_use) =
+ thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
+ {
+ tool_use.status = PendingToolUseStatus::Error(err.to_string());
+ }
}
}
})
@@ -278,7 +311,7 @@ pub enum ThreadEvent {
UsePendingTools,
ToolFinished {
#[allow(unused)]
- tool_use_id: Arc<str>,
+ tool_use_id: LanguageModelToolUseId,
},
}
@@ -291,7 +324,9 @@ struct PendingCompletion {
#[derive(Debug, Clone)]
pub struct PendingToolUse {
- pub id: Arc<str>,
+ pub id: LanguageModelToolUseId,
+ /// The ID of the Assistant message in which the tool use was requested.
+ pub assistant_message_id: MessageId,
pub name: String,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
@@ -63,9 +63,27 @@ pub enum StopReason {
ToolUse,
}
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUseId(Arc<str>);
+
+impl fmt::Display for LanguageModelToolUseId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+impl<T> From<T> for LanguageModelToolUseId
+where
+ T: Into<Arc<str>>,
+{
+ fn from(value: T) -> Self {
+ Self(value.into())
+ }
+}
+
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
- pub id: String,
+ pub id: LanguageModelToolUseId,
pub name: String,
pub input: serde_json::Value,
}
@@ -347,7 +347,7 @@ impl LanguageModelRequest {
}
MessageContent::ToolUse(tool_use) => {
Some(anthropic::RequestContent::ToolUse {
- id: tool_use.id,
+ id: tool_use.id.to_string(),
name: tool_use.name,
input: tool_use.input,
cache_control,
@@ -498,7 +498,7 @@ pub fn map_to_language_model_completion_events(
Some(maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
- id: tool_use.id,
+ id: tool_use.id.into(),
name: tool_use.name,
input: if tool_use.input_json.is_empty() {
serde_json::Value::Null