@@ -255,12 +255,7 @@ impl ActiveThread {
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
self.thread.update(cx, |thread, cx| {
- thread.insert_tool_output(
- tool_use.assistant_message_id,
- tool_use.id.clone(),
- task,
- cx,
- );
+ thread.insert_tool_output(tool_use.id.clone(), task, cx);
});
}
}
@@ -88,8 +88,9 @@ pub struct Thread {
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
tools: Arc<ToolWorkingSet>,
- tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
- tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
+ tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
+ tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
+ tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
@@ -107,8 +108,9 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
- tool_uses_by_message: HashMap::default(),
- tool_results_by_message: HashMap::default(),
+ tool_uses_by_assistant_message: HashMap::default(),
+ tool_uses_by_user_message: HashMap::default(),
+ tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
}
}
@@ -141,8 +143,9 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
- tool_uses_by_message: HashMap::default(),
- tool_results_by_message: HashMap::default(),
+ tool_uses_by_assistant_message: HashMap::default(),
+ tool_uses_by_user_message: HashMap::default(),
+ tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
}
}
@@ -209,26 +212,14 @@ impl Thread {
}
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
- let Some(tool_uses_for_message) = &self.tool_uses_by_message.get(&id) else {
+ let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new();
};
- // The tool use was requested by an Assistant message, so we need to
- // look for the tool results on the next user message.
- let next_user_message = MessageId(id.0 + 1);
-
- let empty = Vec::new();
- let tool_results_for_message = self
- .tool_results_by_message
- .get(&next_user_message)
- .unwrap_or_else(|| &empty);
-
let mut tool_uses = Vec::new();
for tool_use in tool_uses_for_message.iter() {
- let tool_result = tool_results_for_message
- .iter()
- .find(|tool_result| tool_result.tool_use_id == tool_use.id);
+ let tool_result = self.tool_results.get(&tool_use.id);
let status = (|| {
if let Some(tool_result) = tool_result {
@@ -264,7 +255,7 @@ impl Thread {
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
- self.tool_results_by_message
+ self.tool_uses_by_user_message
.get(&message_id)
.map_or(false, |results| !results.is_empty())
}
@@ -369,13 +360,15 @@ impl Thread {
content: Vec::new(),
cache: false,
};
- if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
+ if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) {
match request_kind {
RequestKind::Chat => {
- for tool_result in tool_results {
- request_message
- .content
- .push(MessageContent::ToolResult(tool_result.clone()));
+ for tool_use_id in tool_uses {
+ if let Some(tool_result) = self.tool_results.get(tool_use_id) {
+ request_message
+ .content
+ .push(MessageContent::ToolResult(tool_result.clone()));
+ }
}
}
RequestKind::Summarize => {
@@ -390,7 +383,7 @@ impl Thread {
.push(MessageContent::Text(message.text.clone()));
}
- if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
+ if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) {
match request_kind {
RequestKind::Chat => {
for tool_use in tool_uses {
@@ -477,11 +470,22 @@ impl Thread {
.rfind(|message| message.role == Role::Assistant)
{
thread
- .tool_uses_by_message
+ .tool_uses_by_assistant_message
.entry(last_assistant_message.id)
.or_default()
.push(tool_use.clone());
+ // The tool use is being requested by the
+ // Assistant, so we want to attach the tool
+ // results to the next user message.
+ let next_user_message_id =
+ MessageId(last_assistant_message.id.0 + 1);
+ thread
+ .tool_uses_by_user_message
+ .entry(next_user_message_id)
+ .or_default()
+ .push(tool_use.id.clone());
+
thread.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
PendingToolUse {
@@ -611,7 +615,6 @@ impl Thread {
pub fn insert_tool_output(
&mut self,
- assistant_message_id: MessageId,
tool_use_id: LanguageModelToolUseId,
output: Task<Result<String>>,
cx: &mut Context<Self>,
@@ -621,44 +624,38 @@ impl Thread {
async move {
let output = output.await;
thread
- .update(&mut cx, |thread, cx| {
- // The tool use was requested by an Assistant message,
- // so we want to attach the tool results to the next
- // user message.
- let next_user_message = MessageId(assistant_message_id.0 + 1);
-
- let tool_results = thread
- .tool_results_by_message
- .entry(next_user_message)
- .or_default();
-
- match output {
- Ok(output) => {
- tool_results.push(LanguageModelToolResult {
+ .update(&mut cx, |thread, cx| match output {
+ Ok(output) => {
+ thread.tool_results.insert(
+ tool_use_id.clone(),
+ LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
content: output.into(),
is_error: false,
- });
- thread.pending_tool_uses_by_id.remove(&tool_use_id);
+ },
+ );
+ thread.pending_tool_uses_by_id.remove(&tool_use_id);
- cx.emit(ThreadEvent::ToolFinished { tool_use_id });
- }
- Err(err) => {
- tool_results.push(LanguageModelToolResult {
+ cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+ }
+ Err(err) => {
+ thread.tool_results.insert(
+ tool_use_id.clone(),
+ LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
content: err.to_string().into(),
is_error: true,
- });
-
- if let Some(tool_use) =
- thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
- {
- tool_use.status =
- PendingToolUseStatus::Error(err.to_string().into());
- }
-
- cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+ },
+ );
+
+ if let Some(tool_use) =
+ thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
+ {
+ tool_use.status =
+ PendingToolUseStatus::Error(err.to_string().into());
}
+
+ cx.emit(ThreadEvent::ToolFinished { tool_use_id });
}
})
.ok();