Detailed changes
@@ -490,6 +490,7 @@ dependencies = [
"proto",
"rand 0.8.5",
"rope",
+ "scripting_tool",
"serde",
"serde_json",
"settings",
@@ -11915,7 +11916,6 @@ name = "scripting_tool"
version = "0.1.0"
dependencies = [
"anyhow",
- "assistant_tool",
"collections",
"futures 0.3.31",
"gpui",
@@ -16986,7 +16986,6 @@ dependencies = [
"repl",
"reqwest_client",
"rope",
- "scripting_tool",
"search",
"serde",
"serde_json",
@@ -8,7 +8,6 @@ members = [
"crates/assistant",
"crates/assistant2",
"crates/assistant_context_editor",
- "crates/scripting_tool",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -119,6 +118,7 @@ members = [
"crates/rope",
"crates/rpc",
"crates/schema_generator",
+ "crates/scripting_tool",
"crates/search",
"crates/semantic_index",
"crates/semantic_version",
@@ -59,6 +59,7 @@ prompt_library.workspace = true
prompt_store.workspace = true
proto.workspace = true
rope.workspace = true
+scripting_tool.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -457,9 +457,13 @@ impl ActiveThread {
let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id);
+ let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
// Don't render user messages that are just there for returning tool results.
- if message.role == Role::User && thread.message_has_tool_results(message_id) {
+ if message.role == Role::User
+ && (thread.message_has_tool_results(message_id)
+ || thread.message_has_scripting_tool_results(message_id))
+ {
return Empty.into_any();
}
@@ -609,16 +613,22 @@ impl ActiveThread {
.id(("message-container", ix))
.child(message_content)
.map(|parent| {
- if tool_uses.is_empty() {
+ if tool_uses.is_empty() && scripting_tool_uses.is_empty() {
return parent;
}
parent.child(
- v_flex().children(
- tool_uses
- .into_iter()
- .map(|tool_use| self.render_tool_use(tool_use, cx)),
- ),
+ v_flex()
+ .children(
+ tool_uses
+ .into_iter()
+ .map(|tool_use| self.render_tool_use(tool_use, cx)),
+ )
+ .children(
+ scripting_tool_uses
+ .into_iter()
+ .map(|tool_use| self.render_scripting_tool_use(tool_use, cx)),
+ ),
)
}),
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
@@ -727,6 +737,15 @@ impl ActiveThread {
}),
)
}
+
+ fn render_scripting_tool_use(
+ &self,
+ tool_use: ToolUse,
+ cx: &mut Context<Self>,
+ ) -> impl IntoElement {
+ // TODO: Add custom rendering for scripting tool uses.
+ self.render_tool_use(tool_use, cx)
+ }
}
impl Render for ActiveThread {
@@ -13,13 +13,14 @@ use language_model::{
Role, StopReason,
};
use project::Project;
+use scripting_tool::ScriptingTool;
use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _};
use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::SavedThread;
-use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
+use crate::tool_use::{ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@@ -75,6 +76,7 @@ pub struct Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
+ scripting_tool_use: ToolUseState,
}
impl Thread {
@@ -97,6 +99,7 @@ impl Thread {
project,
tools,
tool_use: ToolUseState::new(),
+ scripting_tool_use: ToolUseState::new(),
}
}
@@ -115,6 +118,7 @@ impl Thread {
.unwrap_or(0),
);
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
+ let scripting_tool_use = ToolUseState::new();
Self {
id,
@@ -138,6 +142,7 @@ impl Thread {
project,
tools,
tool_use,
+ scripting_tool_use,
}
}
@@ -198,31 +203,46 @@ impl Thread {
)
}
- pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
- self.tool_use.pending_tool_uses()
- }
-
/// Returns whether all of the tool uses have finished running.
pub fn all_tools_finished(&self) -> bool {
+ let mut all_pending_tool_uses = self
+ .tool_use
+ .pending_tool_uses()
+ .into_iter()
+ .chain(self.scripting_tool_use.pending_tool_uses());
+
// If the only pending tool uses left are the ones with errors, then that means that we've finished running all
// of the pending tools.
- self.pending_tool_uses()
- .into_iter()
- .all(|tool_use| tool_use.status.is_error())
+ all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
}
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id)
}
+ pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
+ self.scripting_tool_use.tool_uses_for_message(id)
+ }
+
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
self.tool_use.tool_results_for_message(id)
}
+ pub fn scripting_tool_results_for_message(
+ &self,
+ id: MessageId,
+ ) -> Vec<&LanguageModelToolResult> {
+ self.scripting_tool_use.tool_results_for_message(id)
+ }
+
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
+ pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
+ self.scripting_tool_use.message_has_tool_results(message_id)
+ }
+
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
@@ -313,16 +333,25 @@ impl Thread {
let mut request = self.to_completion_request(request_kind, cx);
if use_tools {
- request.tools = self
- .tools()
- .tools(cx)
- .into_iter()
- .map(|tool| LanguageModelRequestTool {
- name: tool.name(),
- description: tool.description(),
- input_schema: tool.input_schema(),
- })
- .collect();
+ let mut tools = Vec::new();
+ tools.push(LanguageModelRequestTool {
+ name: ScriptingTool::NAME.into(),
+ description: ScriptingTool::DESCRIPTION.into(),
+ input_schema: ScriptingTool::input_schema(),
+ });
+
+ tools.extend(
+ self.tools()
+ .tools(cx)
+ .into_iter()
+ .map(|tool| LanguageModelRequestTool {
+ name: tool.name(),
+ description: tool.description(),
+ input_schema: tool.input_schema(),
+ }),
+ );
+
+ request.tools = tools;
}
self.stream_completion(request, model, cx);
@@ -357,6 +386,8 @@ impl Thread {
RequestKind::Chat => {
self.tool_use
.attach_tool_results(message.id, &mut request_message);
+ self.scripting_tool_use
+ .attach_tool_results(message.id, &mut request_message);
}
RequestKind::Summarize => {
// We don't care about tool use during summarization.
@@ -373,6 +404,8 @@ impl Thread {
RequestKind::Chat => {
self.tool_use
.attach_tool_uses(message.id, &mut request_message);
+ self.scripting_tool_use
+ .attach_tool_uses(message.id, &mut request_message);
}
RequestKind::Summarize => {
// We don't care about tool use during summarization.
@@ -450,9 +483,15 @@ impl Thread {
.iter()
.rfind(|message| message.role == Role::Assistant)
{
- thread
- .tool_use
- .request_tool_use(last_assistant_message.id, tool_use);
+ if tool_use.name.as_ref() == ScriptingTool::NAME {
+ thread
+ .scripting_tool_use
+ .request_tool_use(last_assistant_message.id, tool_use);
+ } else {
+ thread
+ .tool_use
+ .request_tool_use(last_assistant_message.id, tool_use);
+ }
}
}
}
@@ -572,6 +611,7 @@ impl Thread {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
let pending_tool_uses = self
+ .tool_use
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
@@ -585,6 +625,20 @@ impl Thread {
self.insert_tool_output(tool_use.id.clone(), task, cx);
}
}
+
+ let pending_scripting_tool_uses = self
+ .scripting_tool_use
+ .pending_tool_uses()
+ .into_iter()
+ .filter(|tool_use| tool_use.status.is_idle())
+ .cloned()
+ .collect::<Vec<_>>();
+
+ for scripting_tool_use in pending_scripting_tool_uses {
+ let task = ScriptingTool.run(scripting_tool_use.input, self.project.clone(), cx);
+
+ self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
+ }
}
pub fn insert_tool_output(
@@ -613,6 +667,32 @@ impl Thread {
.run_pending_tool(tool_use_id, insert_output_task);
}
+ pub fn insert_scripting_tool_output(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ output: Task<Result<String>>,
+ cx: &mut Context<Self>,
+ ) {
+ let insert_output_task = cx.spawn(|thread, mut cx| {
+ let tool_use_id = tool_use_id.clone();
+ async move {
+ let output = output.await;
+ thread
+ .update(&mut cx, |thread, cx| {
+ thread
+ .scripting_tool_use
+ .insert_tool_output(tool_use_id.clone(), output);
+
+ cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+ })
+ .ok();
+ }
+ });
+
+ self.scripting_tool_use
+ .run_pending_tool(tool_use_id, insert_output_task);
+ }
+
pub fn send_tool_results_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -267,6 +267,7 @@ impl ToolUseState {
pub struct PendingToolUse {
pub id: LanguageModelToolUseId,
/// The ID of the Assistant message in which the tool use was requested.
+ #[allow(unused)]
pub assistant_message_id: MessageId,
pub name: Arc<str>,
pub input: serde_json::Value,
@@ -14,7 +14,6 @@ doctest = false
[dependencies]
anyhow.workspace = true
-assistant_tool.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -3,40 +3,29 @@ mod session;
use project::Project;
use session::*;
-use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Entity, Task};
use schemars::JsonSchema;
use serde::Deserialize;
-use std::sync::Arc;
-
-pub fn init(cx: &App) {
- let registry = ToolRegistry::global(cx);
- registry.register_tool(ScriptingTool);
-}
#[derive(Debug, Deserialize, JsonSchema)]
struct ScriptingToolInput {
lua_script: String,
}
-struct ScriptingTool;
+pub struct ScriptingTool;
-impl Tool for ScriptingTool {
- fn name(&self) -> String {
- "lua-interpreter".into()
- }
+impl ScriptingTool {
+ pub const NAME: &str = "lua-interpreter";
- fn description(&self) -> String {
- include_str!("scripting_tool_description.txt").into()
- }
+ pub const DESCRIPTION: &str = include_str!("scripting_tool_description.txt");
- fn input_schema(&self) -> serde_json::Value {
+ pub fn input_schema() -> serde_json::Value {
let schema = schemars::schema_for!(ScriptingToolInput);
serde_json::to_value(&schema).unwrap()
}
- fn run(
- self: Arc<Self>,
+ pub fn run(
+ &self,
input: serde_json::Value,
project: Entity<Project>,
cx: &mut App,
@@ -98,7 +98,6 @@ remote.workspace = true
repl.workspace = true
reqwest_client.workspace = true
rope.workspace = true
-scripting_tool.workspace = true
search.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -476,7 +476,6 @@ fn main() {
cx,
);
assistant_tools::init(cx);
- scripting_tool::init(cx);
repl::init(app_state.fs.clone(), cx);
extension_host::init(
extension_host_proxy,