@@ -20,6 +20,7 @@ use crate::{
};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
+use assistant_tool::ToolRegistry;
use client::{proto, Client, Status};
use collections::{BTreeSet, HashMap, HashSet};
use editor::{
@@ -2091,6 +2092,27 @@ impl ContextEditor {
}
}
}
+ ContextEvent::UsePendingTools => {
+ let pending_tool_uses = self
+ .context
+ .read(cx)
+ .pending_tool_uses()
+ .into_iter()
+ .filter(|tool_use| tool_use.status.is_idle())
+ .cloned()
+ .collect::<Vec<_>>();
+
+ for tool_use in pending_tool_uses {
+ let tool_registry = ToolRegistry::global(cx);
+ if let Some(tool) = tool_registry.tool(&tool_use.name) {
+ let task = tool.run(tool_use.input, self.workspace.clone(), cx);
+
+ self.context.update(cx, |context, cx| {
+ context.insert_tool_output(tool_use.id.clone(), task, cx);
+ });
+ }
+ }
+ }
ContextEvent::Operation(_) => {}
ContextEvent::ShowAssistError(error_message) => {
self.error_message = Some(error_message.clone());
@@ -29,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelRequestTool, MessageContent, Role,
+ LanguageModelRequestTool, MessageContent, Role, StopReason,
};
use open_ai::Model as OpenAiModel;
use paths::{context_images_dir, contexts_dir};
@@ -306,6 +306,7 @@ pub enum ContextEvent {
run_commands_in_output: bool,
expand_result: bool,
},
+ UsePendingTools,
Operation(ContextOperation),
}
@@ -416,6 +417,7 @@ impl Message {
range_start = *image_offset;
}
+
if range_start != self.offset_range.end {
if let Some(text) =
Self::collect_text_content(buffer, range_start..self.offset_range.end)
@@ -492,7 +494,7 @@ pub struct Context {
edits_since_last_parse: language::Subscription,
finished_slash_commands: HashSet<SlashCommandId>,
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
- pending_tool_uses_by_id: HashMap<String, PendingToolUse>,
+ pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
message_anchors: Vec<MessageAnchor>,
images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
image_anchors: Vec<ImageAnchor>,
@@ -1012,7 +1014,7 @@ impl Context {
self.pending_tool_uses_by_id.values().collect()
}
- pub fn get_tool_use_by_id(&self, id: &String) -> Option<&PendingToolUse> {
+ pub fn get_tool_use_by_id(&self, id: &Arc<str>) -> Option<&PendingToolUse> {
self.pending_tool_uses_by_id.get(id)
}
@@ -1919,6 +1921,45 @@ impl Context {
}
}
+ pub fn insert_tool_output(
+ &mut self,
+ tool_id: Arc<str>,
+ output: Task<Result<String>>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ let insert_output_task = cx.spawn(|this, mut cx| {
+ let tool_id = tool_id.clone();
+ async move {
+ let output = output.await;
+ this.update(&mut cx, |this, cx| match output {
+ Ok(mut output) => {
+ if !output.ends_with('\n') {
+ output.push('\n');
+ }
+
+ this.buffer.update(cx, |buffer, cx| {
+ let buffer_end = buffer.len().to_offset(buffer);
+
+ buffer.edit([(buffer_end..buffer_end, output)], None, cx);
+ });
+ }
+ Err(err) => {
+ if let Some(tool_use) = this.pending_tool_uses_by_id.get_mut(&tool_id) {
+ tool_use.status = PendingToolUseStatus::Error(err.to_string());
+ }
+ }
+ })
+ .ok();
+ }
+ });
+
+ if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_id) {
+ tool_use.status = PendingToolUseStatus::Running {
+ _task: insert_output_task.shared(),
+ };
+ }
+ }
+
pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
self.count_remaining_tokens(cx);
}
@@ -1990,7 +2031,7 @@ impl Context {
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
- this.buffer.update(cx, |buffer, cx| {
+ let event_to_emit = this.buffer.update(cx, |buffer, cx| {
let message_old_end_offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
@@ -2000,9 +2041,11 @@ impl Context {
match event {
LanguageModelCompletionEvent::Stop(reason) => match reason {
- language_model::StopReason::ToolUse => {}
- language_model::StopReason::EndTurn => {}
- language_model::StopReason::MaxTokens => {}
+ StopReason::ToolUse => {
+ return Some(ContextEvent::UsePendingTools);
+ }
+ StopReason::EndTurn => {}
+ StopReason::MaxTokens => {}
},
LanguageModelCompletionEvent::Text(chunk) => {
buffer.edit(
@@ -2041,10 +2084,11 @@ impl Context {
let source_range = buffer.anchor_after(start_ix)
..buffer.anchor_after(end_ix);
+ let tool_use_id: Arc<str> = tool_use.id.into();
this.pending_tool_uses_by_id.insert(
- tool_use.id.clone(),
+ tool_use_id.clone(),
PendingToolUse {
- id: tool_use.id,
+ id: tool_use_id,
name: tool_use.name,
input: tool_use.input,
status: PendingToolUseStatus::Idle,
@@ -2053,9 +2097,14 @@ impl Context {
);
}
}
+
+ None
});
cx.emit(ContextEvent::StreamedCompletion);
+ if let Some(event) = event_to_emit {
+ cx.emit(event);
+ }
Some(())
})?;
@@ -2821,7 +2870,7 @@ impl FeatureFlag for ToolUseFeatureFlag {
#[derive(Debug, Clone)]
pub struct PendingToolUse {
- pub id: String,
+ pub id: Arc<str>,
pub name: String,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
@@ -2835,6 +2884,12 @@ pub enum PendingToolUseStatus {
Error(String),
}
+impl PendingToolUseStatus {
+ pub fn is_idle(&self) -> bool {
+ matches!(self, PendingToolUseStatus::Idle)
+ }
+}
+
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
pub id: MessageId,