@@ -4,14 +4,12 @@ use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
-use futures::future::Shared;
-use futures::{FutureExt as _, StreamExt as _};
+use futures::StreamExt as _;
use gpui::{App, Context, EventEmitter, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolUse, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
- PaymentRequiredError, Role, StopReason,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
+ MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
};
use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _};
@@ -19,6 +17,7 @@ use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::SavedThread;
+use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@@ -43,7 +42,7 @@ impl std::fmt::Display for ThreadId {
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
-pub struct MessageId(usize);
+pub struct MessageId(pub(crate) usize);
impl MessageId {
fn post_inc(&mut self) -> Self {
@@ -59,22 +58,6 @@ pub struct Message {
pub text: String,
}
-#[derive(Debug)]
-pub struct ToolUse {
- pub id: LanguageModelToolUseId,
- pub name: SharedString,
- pub status: ToolUseStatus,
- pub input: serde_json::Value,
-}
-
-#[derive(Debug, Clone)]
-pub enum ToolUseStatus {
- Pending,
- Running,
- Finished(SharedString),
- Error(SharedString),
-}
-
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -88,10 +71,7 @@ pub struct Thread {
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
tools: Arc<ToolWorkingSet>,
- tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
- tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
- tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
- pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
+ tool_use: ToolUseState,
}
impl Thread {
@@ -108,10 +88,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
- tool_uses_by_assistant_message: HashMap::default(),
- tool_uses_by_user_message: HashMap::default(),
- tool_results: HashMap::default(),
- pending_tool_uses_by_id: HashMap::default(),
+ tool_use: ToolUseState::default(),
}
}
@@ -143,10 +120,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
- tool_uses_by_assistant_message: HashMap::default(),
- tool_uses_by_user_message: HashMap::default(),
- tool_results: HashMap::default(),
- pending_tool_uses_by_id: HashMap::default(),
+ tool_use: ToolUseState::default(),
}
}
@@ -208,56 +182,15 @@ impl Thread {
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
- self.pending_tool_uses_by_id.values().collect()
+ self.tool_use.pending_tool_uses()
}
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
- let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
- return Vec::new();
- };
-
- let mut tool_uses = Vec::new();
-
- for tool_use in tool_uses_for_message.iter() {
- let tool_result = self.tool_results.get(&tool_use.id);
-
- let status = (|| {
- if let Some(tool_result) = tool_result {
- return if tool_result.is_error {
- ToolUseStatus::Error(tool_result.content.clone().into())
- } else {
- ToolUseStatus::Finished(tool_result.content.clone().into())
- };
- }
-
- if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
- return match pending_tool_use.status {
- PendingToolUseStatus::Idle => ToolUseStatus::Pending,
- PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
- PendingToolUseStatus::Error(ref err) => {
- ToolUseStatus::Error(err.clone().into())
- }
- };
- }
-
- ToolUseStatus::Pending
- })();
-
- tool_uses.push(ToolUse {
- id: tool_use.id.clone(),
- name: tool_use.name.clone().into(),
- input: tool_use.input.clone(),
- status,
- })
- }
-
- tool_uses
+ self.tool_use.tool_uses_for_message(id)
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
- self.tool_uses_by_user_message
- .get(&message_id)
- .map_or(false, |results| !results.is_empty())
+ self.tool_use.message_has_tool_results(message_id)
}
pub fn insert_user_message(
@@ -360,20 +293,13 @@ impl Thread {
content: Vec::new(),
cache: false,
};
- if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) {
- match request_kind {
- RequestKind::Chat => {
- for tool_use_id in tool_uses {
- if let Some(tool_result) = self.tool_results.get(tool_use_id) {
- request_message
- .content
- .push(MessageContent::ToolResult(tool_result.clone()));
- }
- }
- }
- RequestKind::Summarize => {
- // We don't care about tool use during summarization.
- }
+ match request_kind {
+ RequestKind::Chat => {
+ self.tool_use
+ .attach_tool_results(message.id, &mut request_message);
+ }
+ RequestKind::Summarize => {
+ // We don't care about tool use during summarization.
}
}
@@ -383,18 +309,13 @@ impl Thread {
.push(MessageContent::Text(message.text.clone()));
}
- if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) {
- match request_kind {
- RequestKind::Chat => {
- for tool_use in tool_uses {
- request_message
- .content
- .push(MessageContent::ToolUse(tool_use.clone()));
- }
- }
- RequestKind::Summarize => {
- // We don't care about tool use during summarization.
- }
+ match request_kind {
+ RequestKind::Chat => {
+ self.tool_use
+ .attach_tool_uses(message.id, &mut request_message);
+ }
+ RequestKind::Summarize => {
+ // We don't care about tool use during summarization.
}
}
@@ -470,32 +391,8 @@ impl Thread {
.rfind(|message| message.role == Role::Assistant)
{
thread
- .tool_uses_by_assistant_message
- .entry(last_assistant_message.id)
- .or_default()
- .push(tool_use.clone());
-
- // The tool use is being requested by the
- // Assistant, so we want to attach the tool
- // results to the next user message.
- let next_user_message_id =
- MessageId(last_assistant_message.id.0 + 1);
- thread
- .tool_uses_by_user_message
- .entry(next_user_message_id)
- .or_default()
- .push(tool_use.id.clone());
-
- thread.pending_tool_uses_by_id.insert(
- tool_use.id.clone(),
- PendingToolUse {
- assistant_message_id: last_assistant_message.id,
- id: tool_use.id,
- name: tool_use.name,
- input: tool_use.input,
- status: PendingToolUseStatus::Idle,
- },
- );
+ .tool_use
+ .request_tool_use(last_assistant_message.id, tool_use);
}
}
}
@@ -624,49 +521,19 @@ impl Thread {
async move {
let output = output.await;
thread
- .update(&mut cx, |thread, cx| match output {
- Ok(output) => {
- thread.tool_results.insert(
- tool_use_id.clone(),
- LanguageModelToolResult {
- tool_use_id: tool_use_id.clone(),
- content: output.into(),
- is_error: false,
- },
- );
- thread.pending_tool_uses_by_id.remove(&tool_use_id);
-
- cx.emit(ThreadEvent::ToolFinished { tool_use_id });
- }
- Err(err) => {
- thread.tool_results.insert(
- tool_use_id.clone(),
- LanguageModelToolResult {
- tool_use_id: tool_use_id.clone(),
- content: err.to_string().into(),
- is_error: true,
- },
- );
-
- if let Some(tool_use) =
- thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
- {
- tool_use.status =
- PendingToolUseStatus::Error(err.to_string().into());
- }
+ .update(&mut cx, |thread, cx| {
+ thread
+ .tool_use
+ .insert_tool_output(tool_use_id.clone(), output);
- cx.emit(ThreadEvent::ToolFinished { tool_use_id });
- }
+ cx.emit(ThreadEvent::ToolFinished { tool_use_id });
})
.ok();
}
});
- if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
- tool_use.status = PendingToolUseStatus::Running {
- _task: insert_output_task.shared(),
- };
- }
+ self.tool_use
+ .run_pending_tool(tool_use_id, insert_output_task);
}
/// Cancels the last pending completion, if there are any pending.
@@ -708,30 +575,3 @@ struct PendingCompletion {
id: usize,
_task: Task<()>,
}
-
-#[derive(Debug, Clone)]
-pub struct PendingToolUse {
- pub id: LanguageModelToolUseId,
- /// The ID of the Assistant message in which the tool use was requested.
- pub assistant_message_id: MessageId,
- pub name: Arc<str>,
- pub input: serde_json::Value,
- pub status: PendingToolUseStatus,
-}
-
-#[derive(Debug, Clone)]
-pub enum PendingToolUseStatus {
- Idle,
- Running { _task: Shared<Task<()>> },
- Error(#[allow(unused)] Arc<str>),
-}
-
-impl PendingToolUseStatus {
- pub fn is_idle(&self) -> bool {
- matches!(self, PendingToolUseStatus::Idle)
- }
-
- pub fn is_error(&self) -> bool {
- matches!(self, PendingToolUseStatus::Error(_))
- }
-}
@@ -0,0 +1,221 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use collections::HashMap;
+use futures::future::Shared;
+use futures::FutureExt as _;
+use gpui::{SharedString, Task};
+use language_model::{
+ LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
+ LanguageModelToolUseId, MessageContent,
+};
+
+use crate::thread::MessageId;
+
+#[derive(Debug)]
+pub struct ToolUse {
+ pub id: LanguageModelToolUseId,
+ pub name: SharedString,
+ pub status: ToolUseStatus,
+ pub input: serde_json::Value,
+}
+
+#[derive(Debug, Clone)]
+pub enum ToolUseStatus {
+ Pending,
+ Running,
+ Finished(SharedString),
+ Error(SharedString),
+}
+
+#[derive(Default)]
+pub struct ToolUseState {
+ tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
+ tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
+ tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
+ pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
+}
+
+impl ToolUseState {
+ pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
+ self.pending_tool_uses_by_id.values().collect()
+ }
+
+ pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
+ let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
+ return Vec::new();
+ };
+
+ let mut tool_uses = Vec::new();
+
+ for tool_use in tool_uses_for_message.iter() {
+ let tool_result = self.tool_results.get(&tool_use.id);
+
+ let status = (|| {
+ if let Some(tool_result) = tool_result {
+ return if tool_result.is_error {
+ ToolUseStatus::Error(tool_result.content.clone().into())
+ } else {
+ ToolUseStatus::Finished(tool_result.content.clone().into())
+ };
+ }
+
+ if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
+ return match pending_tool_use.status {
+ PendingToolUseStatus::Idle => ToolUseStatus::Pending,
+ PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
+ PendingToolUseStatus::Error(ref err) => {
+ ToolUseStatus::Error(err.clone().into())
+ }
+ };
+ }
+
+ ToolUseStatus::Pending
+ })();
+
+ tool_uses.push(ToolUse {
+ id: tool_use.id.clone(),
+ name: tool_use.name.clone().into(),
+ input: tool_use.input.clone(),
+ status,
+ })
+ }
+
+ tool_uses
+ }
+
+ pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
+ self.tool_uses_by_user_message
+ .get(&message_id)
+ .map_or(false, |results| !results.is_empty())
+ }
+
+ pub fn request_tool_use(
+ &mut self,
+ assistant_message_id: MessageId,
+ tool_use: LanguageModelToolUse,
+ ) {
+ self.tool_uses_by_assistant_message
+ .entry(assistant_message_id)
+ .or_default()
+ .push(tool_use.clone());
+
+ // The tool use is being requested by the Assistant, so we want to
+ // attach the tool results to the next user message.
+ let next_user_message_id = MessageId(assistant_message_id.0 + 1);
+ self.tool_uses_by_user_message
+ .entry(next_user_message_id)
+ .or_default()
+ .push(tool_use.id.clone());
+
+ self.pending_tool_uses_by_id.insert(
+ tool_use.id.clone(),
+ PendingToolUse {
+ assistant_message_id,
+ id: tool_use.id,
+ name: tool_use.name,
+ input: tool_use.input,
+ status: PendingToolUseStatus::Idle,
+ },
+ );
+ }
+
+ pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
+ if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
+ tool_use.status = PendingToolUseStatus::Running {
+ _task: task.shared(),
+ };
+ }
+ }
+
+ pub fn insert_tool_output(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ output: Result<String>,
+ ) {
+ match output {
+ Ok(output) => {
+ self.tool_results.insert(
+ tool_use_id.clone(),
+ LanguageModelToolResult {
+ tool_use_id: tool_use_id.clone(),
+ content: output.into(),
+ is_error: false,
+ },
+ );
+ self.pending_tool_uses_by_id.remove(&tool_use_id);
+ }
+ Err(err) => {
+ self.tool_results.insert(
+ tool_use_id.clone(),
+ LanguageModelToolResult {
+ tool_use_id: tool_use_id.clone(),
+ content: err.to_string().into(),
+ is_error: true,
+ },
+ );
+
+ if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
+ tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
+ }
+ }
+ }
+ }
+
+ pub fn attach_tool_uses(
+ &self,
+ message_id: MessageId,
+ request_message: &mut LanguageModelRequestMessage,
+ ) {
+ if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
+ for tool_use in tool_uses {
+ request_message
+ .content
+ .push(MessageContent::ToolUse(tool_use.clone()));
+ }
+ }
+ }
+
+ pub fn attach_tool_results(
+ &self,
+ message_id: MessageId,
+ request_message: &mut LanguageModelRequestMessage,
+ ) {
+ if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
+ for tool_use_id in tool_uses {
+ if let Some(tool_result) = self.tool_results.get(tool_use_id) {
+ request_message
+ .content
+ .push(MessageContent::ToolResult(tool_result.clone()));
+ }
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct PendingToolUse {
+ pub id: LanguageModelToolUseId,
+ /// The ID of the Assistant message in which the tool use was requested.
+ pub assistant_message_id: MessageId,
+ pub name: Arc<str>,
+ pub input: serde_json::Value,
+ pub status: PendingToolUseStatus,
+}
+
+#[derive(Debug, Clone)]
+pub enum PendingToolUseStatus {
+ Idle,
+ Running { _task: Shared<Task<()>> },
+ Error(#[allow(unused)] Arc<str>),
+}
+
+impl PendingToolUseStatus {
+ pub fn is_idle(&self) -> bool {
+ matches!(self, PendingToolUseStatus::Idle)
+ }
+
+ pub fn is_error(&self) -> bool {
+ matches!(self, PendingToolUseStatus::Error(_))
+ }
+}