Detailed changes
@@ -450,7 +450,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assistant_context_editor",
- "assistant_scripting",
"assistant_settings",
"assistant_slash_command",
"assistant_tool",
@@ -564,26 +563,6 @@ dependencies = [
"workspace",
]
-[[package]]
-name = "assistant_scripting"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "collections",
- "futures 0.3.31",
- "gpui",
- "log",
- "mlua",
- "parking_lot",
- "project",
- "rand 0.8.5",
- "regex",
- "serde",
- "serde_json",
- "settings",
- "util",
-]
-
[[package]]
name = "assistant_settings"
version = "0.1.0"
@@ -11931,6 +11910,28 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
+[[package]]
+name = "scripting_tool"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "assistant_tool",
+ "collections",
+ "futures 0.3.31",
+ "gpui",
+ "log",
+ "mlua",
+ "parking_lot",
+ "project",
+ "rand 0.8.5",
+ "regex",
+ "schemars",
+ "serde",
+ "serde_json",
+ "settings",
+ "util",
+]
+
[[package]]
name = "scrypt"
version = "0.11.0"
@@ -16985,6 +16986,7 @@ dependencies = [
"repl",
"reqwest_client",
"rope",
+ "scripting_tool",
"search",
"serde",
"serde_json",
@@ -8,7 +8,7 @@ members = [
"crates/assistant",
"crates/assistant2",
"crates/assistant_context_editor",
- "crates/assistant_scripting",
+ "crates/scripting_tool",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -318,7 +318,7 @@ reqwest_client = { path = "crates/reqwest_client" }
rich_text = { path = "crates/rich_text" }
rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" }
-assistant_scripting = { path = "crates/assistant_scripting" }
+scripting_tool = { path = "crates/scripting_tool" }
search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" }
semantic_version = { path = "crates/semantic_version" }
@@ -21,7 +21,6 @@ test-support = [
[dependencies]
anyhow.workspace = true
assistant_context_editor.workspace = true
-assistant_scripting.workspace = true
assistant_settings.workspace = true
assistant_slash_command.workspace = true
assistant_tool.workspace = true
@@ -1,12 +1,11 @@
use std::sync::Arc;
-use assistant_scripting::{ScriptId, ScriptState};
-use collections::{HashMap, HashSet};
+use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
- Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
+ Task, TextStyleRefinement, UnderlineStyle,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@@ -15,7 +14,6 @@ use settings::Settings as _;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _;
-use workspace::Workspace;
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore;
@@ -23,7 +21,6 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
pub struct ActiveThread {
- workspace: WeakEntity<Workspace>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
@@ -33,7 +30,6 @@ pub struct ActiveThread {
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
- expanded_scripts: HashSet<ScriptId>,
last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>,
}
@@ -44,7 +40,6 @@ struct EditMessageState {
impl ActiveThread {
pub fn new(
- workspace: WeakEntity<Workspace>,
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>,
@@ -57,7 +52,6 @@ impl ActiveThread {
];
let mut this = Self {
- workspace,
language_registry,
thread_store,
thread: thread.clone(),
@@ -65,7 +59,6 @@ impl ActiveThread {
messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
expanded_tool_uses: HashMap::default(),
- expanded_scripts: HashSet::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| {
@@ -466,10 +459,7 @@ impl ActiveThread {
let tool_uses = thread.tool_uses_for_message(message_id);
// Don't render user messages that are just there for returning tool results.
- if message.role == Role::User
- && (thread.message_has_tool_results(message_id)
- || thread.message_has_script_output(message_id))
- {
+ if message.role == Role::User && thread.message_has_tool_results(message_id) {
return Empty.into_any();
}
@@ -618,7 +608,6 @@ impl ActiveThread {
Role::Assistant => div()
.id(("message-container", ix))
.child(message_content)
- .children(self.render_script(message_id, cx))
.map(|parent| {
if tool_uses.is_empty() {
return parent;
@@ -738,139 +727,6 @@ impl ActiveThread {
}),
)
}
-
- fn render_script(&self, message_id: MessageId, cx: &mut Context<Self>) -> Option<AnyElement> {
- let script = self.thread.read(cx).script_for_message(message_id, cx)?;
-
- let is_open = self.expanded_scripts.contains(&script.id);
- let colors = cx.theme().colors();
-
- let element = div().px_2p5().child(
- v_flex()
- .gap_1()
- .rounded_lg()
- .border_1()
- .border_color(colors.border)
- .child(
- h_flex()
- .justify_between()
- .py_0p5()
- .pl_1()
- .pr_2()
- .bg(colors.editor_foreground.opacity(0.02))
- .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
- .when(!is_open, |element| element.rounded_md())
- .border_color(colors.border)
- .child(
- h_flex()
- .gap_1()
- .child(Disclosure::new("script-disclosure", is_open).on_click(
- cx.listener({
- let script_id = script.id;
- move |this, _event, _window, _cx| {
- if this.expanded_scripts.contains(&script_id) {
- this.expanded_scripts.remove(&script_id);
- } else {
- this.expanded_scripts.insert(script_id);
- }
- }
- }),
- ))
- // TODO: Generate script description
- .child(Label::new("Script")),
- )
- .child(
- h_flex()
- .gap_1()
- .child(
- Label::new(match script.state {
- ScriptState::Generating => "Generating",
- ScriptState::Running { .. } => "Running",
- ScriptState::Succeeded { .. } => "Finished",
- ScriptState::Failed { .. } => "Error",
- })
- .size(LabelSize::XSmall)
- .buffer_font(cx),
- )
- .child(
- IconButton::new("view-source", IconName::Eye)
- .icon_color(Color::Muted)
- .disabled(matches!(script.state, ScriptState::Generating))
- .on_click(cx.listener({
- let source = script.source.clone();
- move |this, _event, window, cx| {
- this.open_script_source(source.clone(), window, cx);
- }
- })),
- ),
- ),
- )
- .when(is_open, |parent| {
- let stdout = script.stdout_snapshot();
- let error = script.error();
-
- parent.child(
- v_flex()
- .p_2()
- .bg(colors.editor_background)
- .gap_2()
- .child(if stdout.is_empty() && error.is_none() {
- Label::new("No output yet")
- .size(LabelSize::Small)
- .color(Color::Muted)
- } else {
- Label::new(stdout).size(LabelSize::Small).buffer_font(cx)
- })
- .children(script.error().map(|err| {
- Label::new(err.to_string())
- .size(LabelSize::Small)
- .color(Color::Error)
- })),
- )
- }),
- );
-
- Some(element.into_any())
- }
-
- fn open_script_source(
- &mut self,
- source: SharedString,
- window: &mut Window,
- cx: &mut Context<'_, ActiveThread>,
- ) {
- let language_registry = self.language_registry.clone();
- let workspace = self.workspace.clone();
- let source = source.clone();
-
- cx.spawn_in(window, |_, mut cx| async move {
- let lua = language_registry.language_for_name("Lua").await.log_err();
-
- workspace.update_in(&mut cx, |workspace, window, cx| {
- let project = workspace.project().clone();
-
- let buffer = project.update(cx, |project, cx| {
- project.create_local_buffer(&source.trim(), lua, cx)
- });
-
- let buffer = cx.new(|cx| {
- MultiBuffer::singleton(buffer, cx)
- // TODO: Generate script description
- .with_title("Assistant script".into())
- });
-
- let editor = cx.new(|cx| {
- let mut editor =
- Editor::for_multibuffer(buffer, Some(project), true, window, cx);
- editor.set_read_only(true);
- editor
- });
-
- workspace.add_item_to_active_pane(Box::new(editor), None, true, window, cx);
- })
- })
- .detach_and_log_err(cx);
- }
}
impl Render for ActiveThread {
@@ -168,7 +168,6 @@ impl AssistantPanel {
let thread = cx.new(|cx| {
ActiveThread::new(
- workspace.clone(),
thread.clone(),
thread_store.clone(),
language_registry.clone(),
@@ -242,7 +241,6 @@ impl AssistantPanel {
self.active_view = ActiveView::Thread;
self.thread = cx.new(|cx| {
ActiveThread::new(
- self.workspace.clone(),
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
@@ -376,7 +374,6 @@ impl AssistantPanel {
this.active_view = ActiveView::Thread;
this.thread = cx.new(|cx| {
ActiveThread::new(
- this.workspace.clone(),
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
@@ -1,14 +1,11 @@
use std::sync::Arc;
use anyhow::Result;
-use assistant_scripting::{
- Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
-};
use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _;
-use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
+use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
@@ -78,21 +75,14 @@ pub struct Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
- scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
- script_output_messages: HashSet<MessageId>,
- script_session: Entity<ScriptSession>,
- _script_session_subscription: Subscription,
}
impl Thread {
pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
- cx: &mut Context<Self>,
+ _cx: &mut Context<Self>,
) -> Self {
- let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
- let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
-
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@@ -107,10 +97,6 @@ impl Thread {
project,
tools,
tool_use: ToolUseState::new(),
- scripts_by_assistant_message: HashMap::default(),
- script_output_messages: HashSet::default(),
- script_session,
- _script_session_subscription: script_session_subscription,
}
}
@@ -119,7 +105,7 @@ impl Thread {
saved: SavedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
- cx: &mut Context<Self>,
+ _cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
saved
@@ -129,8 +115,6 @@ impl Thread {
.unwrap_or(0),
);
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
- let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
- let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
Self {
id,
@@ -154,10 +138,6 @@ impl Thread {
project,
tools,
tool_use,
- scripts_by_assistant_message: HashMap::default(),
- script_output_messages: HashSet::default(),
- script_session,
- _script_session_subscription: script_session_subscription,
}
}
@@ -243,10 +223,6 @@ impl Thread {
self.tool_use.message_has_tool_results(message_id)
}
- pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
- self.script_output_messages.contains(&message_id)
- }
-
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
@@ -327,39 +303,6 @@ impl Thread {
text
}
- pub fn script_for_message<'a>(
- &'a self,
- message_id: MessageId,
- cx: &'a App,
- ) -> Option<&'a Script> {
- self.scripts_by_assistant_message
- .get(&message_id)
- .map(|script_id| self.script_session.read(cx).get(*script_id))
- }
-
- fn handle_script_event(
- &mut self,
- _script_session: Entity<ScriptSession>,
- event: &ScriptEvent,
- cx: &mut Context<Self>,
- ) {
- match event {
- ScriptEvent::Spawned(_) => {}
- ScriptEvent::Exited(script_id) => {
- if let Some(output_message) = self
- .script_session
- .read(cx)
- .get(*script_id)
- .output_message_for_llm()
- {
- let message_id = self.insert_user_message(output_message, vec![], cx);
- self.script_output_messages.insert(message_id);
- cx.emit(ThreadEvent::ScriptFinished)
- }
- }
- }
- }
-
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -388,7 +331,7 @@ impl Thread {
pub fn to_completion_request(
&self,
request_kind: RequestKind,
- cx: &App,
+ _cx: &App,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
messages: vec![],
@@ -397,12 +340,6 @@ impl Thread {
temperature: None,
};
- request.messages.push(LanguageModelRequestMessage {
- role: Role::System,
- content: vec![SCRIPTING_PROMPT.to_string().into()],
- cache: true,
- });
-
let mut referenced_context_ids = HashSet::default();
for message in &self.messages {
@@ -436,15 +373,6 @@ impl Thread {
RequestKind::Chat => {
self.tool_use
.attach_tool_uses(message.id, &mut request_message);
-
- if matches!(message.role, Role::Assistant) {
- if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
- {
- let script = self.script_session.read(cx).get(*script_id);
-
- request_message.content.push(script.source_tag().into());
- }
- }
}
RequestKind::Summarize => {
// We don't care about tool use during summarization.
@@ -486,8 +414,6 @@ impl Thread {
let stream_completion = async {
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
- let mut script_tag_parser = ScriptTagParser::new();
- let mut script_id = None;
while let Some(event) = events.next().await {
let event = event?;
@@ -502,44 +428,20 @@ impl Thread {
}
LanguageModelCompletionEvent::Text(chunk) => {
if let Some(last_message) = thread.messages.last_mut() {
- let chunk = script_tag_parser.parse_chunk(&chunk);
-
- let message_id = if last_message.role == Role::Assistant {
- last_message.text.push_str(&chunk.content);
+ if last_message.role == Role::Assistant {
+ last_message.text.push_str(&chunk);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
- chunk.content,
+ chunk,
));
- last_message.id
} else {
// If we won't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response.
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
- thread.insert_message(Role::Assistant, chunk.content, cx)
+ thread.insert_message(Role::Assistant, chunk, cx);
};
-
- if script_id.is_none() && script_tag_parser.found_script() {
- let id = thread
- .script_session
- .update(cx, |session, _cx| session.new_script());
- thread.scripts_by_assistant_message.insert(message_id, id);
-
- script_id = Some(id);
- }
-
- if let (Some(script_source), Some(script_id)) =
- (chunk.script_source, script_id)
- {
- // TODO: move buffer to script and run as it streams
- thread
- .script_session
- .update(cx, |this, cx| {
- this.run_script(script_id, script_source, cx)
- })
- .detach_and_log_err(cx);
- }
}
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
@@ -1,7 +0,0 @@
-mod session;
-mod tag;
-
-pub use session::*;
-pub use tag::*;
-
-pub const SCRIPTING_PROMPT: &str = include_str!("./system_prompt.txt");
@@ -1,36 +0,0 @@
-You can write a Lua script and I'll run it on my codebase and tell you what its
-output was, including both stdout as well as the git diff of changes it made to
-the filesystem. That way, you can get more information about the code base, or
-make changes to the code base directly.
-
-Put the Lua script inside of an `<eval>` tag like so:
-
-<eval type="lua">
-print("Hello, world!")
-</eval>
-
-The Lua script will have access to `io` and it will run with the current working
-directory being in the root of the code base, so you can use it to explore,
-search, make changes, etc. You can also have the script print things, and I'll
-tell you what the output was. Note that `io` only has `open`, and then the file
-it returns only has the methods read, write, and close - it doesn't have popen
-or anything else.
-
-There is a function called `search` which accepts a regex (it's implemented
-using Rust's regex crate, so use that regex syntax) and runs that regex on the
-contents of every file in the code base (aside from gitignored files), then
-returns an array of tables with two fields: "path" (the path to the file that
-had the matches) and "matches" (an array of strings, with each string being a
-match that was found within the file).
-
-There is a function called `outline` which accepts the path to a source file,
-and returns a string where each line is a declaration. These lines are indented
-with 2 spaces to indicate when a declaration is inside another.
-
-When I send you the script output, do not thank me for running it,
-act as if you ran it yourself.
-
-IMPORTANT!
-Only include a maximum of one Lua script at the very end of your message
-DO NOT WRITE ANYTHING ELSE AFTER THE SCRIPT. Wait for my response with the script
-output to continue.
@@ -1,260 +0,0 @@
-pub const SCRIPT_START_TAG: &str = "<eval type=\"lua\">";
-pub const SCRIPT_END_TAG: &str = "</eval>";
-
-const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes();
-const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes();
-
-/// Parses a script tag in an assistant message as it is being streamed.
-pub struct ScriptTagParser {
- state: State,
- buffer: Vec<u8>,
- tag_match_ix: usize,
-}
-
-enum State {
- Unstarted,
- Streaming,
- Ended,
-}
-
-#[derive(Debug, PartialEq)]
-pub struct ChunkOutput {
- /// The chunk with script tags removed.
- pub content: String,
- /// The full script tag content. `None` until closed.
- pub script_source: Option<String>,
-}
-
-impl ScriptTagParser {
- /// Create a new script tag parser.
- pub fn new() -> Self {
- Self {
- state: State::Unstarted,
- buffer: Vec::new(),
- tag_match_ix: 0,
- }
- }
-
- /// Returns true if the parser has found a script tag.
- pub fn found_script(&self) -> bool {
- match self.state {
- State::Unstarted => false,
- State::Streaming | State::Ended => true,
- }
- }
-
- /// Process a new chunk of input, splitting it into surrounding content and script source.
- pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput {
- let mut content = Vec::with_capacity(input.len());
-
- for byte in input.bytes() {
- match self.state {
- State::Unstarted => {
- if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) {
- self.state = State::Streaming;
- self.buffer = Vec::with_capacity(1024);
- self.tag_match_ix = 0;
- }
- }
- State::Streaming => {
- if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) {
- self.state = State::Ended;
- }
- }
- State::Ended => content.push(byte),
- }
- }
-
- let content = unsafe { String::from_utf8_unchecked(content) };
-
- let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() {
- let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) };
-
- Some(source)
- } else {
- None
- };
-
- ChunkOutput {
- content,
- script_source,
- }
- }
-}
-
-fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec<u8>) -> bool {
- // this can't be a method because it'd require a mutable borrow on both self and self.buffer
-
- if match_tag_byte(byte, tag, tag_match_ix) {
- *tag_match_ix >= tag.len()
- } else {
- if *tag_match_ix > 0 {
- // push the partially matched tag to the buffer
- buffer.extend_from_slice(&tag[..*tag_match_ix]);
- *tag_match_ix = 0;
-
- // the tag might start to match again
- if match_tag_byte(byte, tag, tag_match_ix) {
- return *tag_match_ix >= tag.len();
- }
- }
-
- buffer.push(byte);
-
- false
- }
-}
-
-fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool {
- if byte == tag[*tag_match_ix] {
- *tag_match_ix += 1;
- true
- } else {
- false
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_parse_complete_tag() {
- let mut parser = ScriptTagParser::new();
- let input = "<eval type=\"lua\">print(\"Hello, World!\")</eval>";
- let result = parser.parse_chunk(input);
- assert_eq!(result.content, "");
- assert_eq!(
- result.script_source,
- Some("print(\"Hello, World!\")".to_string())
- );
- }
-
- #[test]
- fn test_no_tag() {
- let mut parser = ScriptTagParser::new();
- let input = "No tags here, just plain text";
- let result = parser.parse_chunk(input);
- assert_eq!(result.content, "No tags here, just plain text");
- assert_eq!(result.script_source, None);
- }
-
- #[test]
- fn test_partial_end_tag() {
- let mut parser = ScriptTagParser::new();
-
- // Start the tag
- let result = parser.parse_chunk("<eval type=\"lua\">let x = '</e");
- assert_eq!(result.content, "");
- assert_eq!(result.script_source, None);
-
- // Finish with the rest
- let result = parser.parse_chunk("val' + 'not the end';</eval>");
- assert_eq!(result.content, "");
- assert_eq!(
- result.script_source,
- Some("let x = '</eval' + 'not the end';".to_string())
- );
- }
-
- #[test]
- fn test_text_before_and_after_tag() {
- let mut parser = ScriptTagParser::new();
- let input = "Before tag <eval type=\"lua\">print(\"Hello\")</eval> After tag";
- let result = parser.parse_chunk(input);
- assert_eq!(result.content, "Before tag After tag");
- assert_eq!(result.script_source, Some("print(\"Hello\")".to_string()));
- }
-
- #[test]
- fn test_multiple_chunks_with_surrounding_text() {
- let mut parser = ScriptTagParser::new();
-
- // First chunk with text before
- let result = parser.parse_chunk("Before script <eval type=\"lua\">local x = 10");
- assert_eq!(result.content, "Before script ");
- assert_eq!(result.script_source, None);
-
- // Second chunk with script content
- let result = parser.parse_chunk("\nlocal y = 20");
- assert_eq!(result.content, "");
- assert_eq!(result.script_source, None);
-
- // Last chunk with text after
- let result = parser.parse_chunk("\nprint(x + y)</eval> After script");
- assert_eq!(result.content, " After script");
- assert_eq!(
- result.script_source,
- Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string())
- );
-
- let result = parser.parse_chunk(" there's more text");
- assert_eq!(result.content, " there's more text");
- assert_eq!(result.script_source, None);
- }
-
- #[test]
- fn test_partial_start_tag_matching() {
- let mut parser = ScriptTagParser::new();
-
- // partial match of start tag...
- let result = parser.parse_chunk("<ev");
- assert_eq!(result.content, "");
-
- // ...that's abandandoned when the < of a real tag is encountered
- let result = parser.parse_chunk("<eval type=\"lua\">script content</eval>");
- // ...so it gets pushed to content
- assert_eq!(result.content, "<ev");
- // ...and the real tag is parsed correctly
- assert_eq!(result.script_source, Some("script content".to_string()));
- }
-
- #[test]
- fn test_random_chunked_parsing() {
- use rand::rngs::StdRng;
- use rand::{Rng, SeedableRng};
- use std::time::{SystemTime, UNIX_EPOCH};
-
- let test_inputs = [
- "Before <eval type=\"lua\">print(\"Hello\")</eval> After",
- "No tags here at all",
- "<eval type=\"lua\">local x = 10\nlocal y = 20\nprint(x + y)</eval>",
- "Text <eval type=\"lua\">if true then\nprint(\"nested </e\")\nend</eval> more",
- ];
-
- let seed = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap()
- .as_secs();
-
- eprintln!("Using random seed: {}", seed);
- let mut rng = StdRng::seed_from_u64(seed);
-
- for test_input in &test_inputs {
- let mut reference_parser = ScriptTagParser::new();
- let expected = reference_parser.parse_chunk(test_input);
-
- let mut chunked_parser = ScriptTagParser::new();
- let mut remaining = test_input.as_bytes();
- let mut actual_content = String::new();
- let mut actual_script = None;
-
- while !remaining.is_empty() {
- let chunk_size = rng.gen_range(1..=remaining.len().min(5));
- let (chunk, rest) = remaining.split_at(chunk_size);
- remaining = rest;
-
- let chunk_str = std::str::from_utf8(chunk).unwrap();
- let result = chunked_parser.parse_chunk(chunk_str);
-
- actual_content.push_str(&result.content);
- if result.script_source.is_some() {
- actual_script = result.script_source;
- }
- }
-
- assert_eq!(actual_content, expected.content);
- assert_eq!(actual_script, expected.script_source);
- }
- }
-}
@@ -1,5 +1,5 @@
[package]
-name = "assistant_scripting"
+name = "scripting_tool"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@@ -9,11 +9,12 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
-path = "src/assistant_scripting.rs"
+path = "src/scripting_tool.rs"
doctest = false
[dependencies]
anyhow.workspace = true
+assistant_tool.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -22,6 +23,7 @@ mlua.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
+schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -0,0 +1,74 @@
+mod session;
+
+use project::Project;
+use session::*;
+
+use assistant_tool::{Tool, ToolRegistry};
+use gpui::{App, AppContext as _, Entity, Task};
+use schemars::JsonSchema;
+use serde::Deserialize;
+use std::sync::Arc;
+
+pub fn init(cx: &App) {
+ let registry = ToolRegistry::global(cx);
+ registry.register_tool(ScriptingTool);
+}
+
+#[derive(Debug, Deserialize, JsonSchema)]
+struct ScriptingToolInput {
+ lua_script: String,
+}
+
+struct ScriptingTool;
+
+impl Tool for ScriptingTool {
+ fn name(&self) -> String {
+ "lua-interpreter".into()
+ }
+
+ fn description(&self) -> String {
+ include_str!("scripting_tool_description.txt").into()
+ }
+
+ fn input_schema(&self) -> serde_json::Value {
+ let schema = schemars::schema_for!(ScriptingToolInput);
+ serde_json::to_value(&schema).unwrap()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: serde_json::Value,
+ project: Entity<Project>,
+ cx: &mut App,
+ ) -> Task<anyhow::Result<String>> {
+ let input = match serde_json::from_value::<ScriptingToolInput>(input) {
+ Err(err) => return Task::ready(Err(err.into())),
+ Ok(input) => input,
+ };
+
+ // TODO: Store a session per thread
+ let session = cx.new(|cx| ScriptSession::new(project, cx));
+ let lua_script = input.lua_script;
+
+ let (script_id, script_task) =
+ session.update(cx, |session, cx| session.run_script(lua_script, cx));
+
+ cx.spawn(|cx| async move {
+ script_task.await;
+
+ let message = session.read_with(&cx, |session, _cx| {
+ // Using a id to get the script output seems impractical.
+ // Why not just include it in the Task result?
+ // This is because we'll later report the script state as it runs,
+ // currently not supported by the `Tool` interface.
+ session
+ .get(script_id)
+ .output_message_for_llm()
+ .expect("Script shouldn't still be running")
+ })?;
+
+ drop(session);
+ Ok(message)
+ })
+ }
+}
@@ -0,0 +1,22 @@
+You can write a Lua script and I'll run it on my codebase and tell you what its
+output was, including both stdout as well as the git diff of changes it made to
+ the filesystem. That way, you can get more information about the code base, or
+ make changes to the code base directly.
+
+ The Lua script will have access to `io` and it will run with the current working
+ directory being in the root of the code base, so you can use it to explore,
+ search, make changes, etc. You can also have the script print things, and I'll
+ tell you what the output was. Note that `io` only has `open`, and then the file
+ it returns only has the methods read, write, and close - it doesn't have popen
+ or anything else.
+
+ Also, I'm going to be putting this Lua script into JSON, so please don't use
+ Lua's double quote syntax for string literals - use one of Lua's other syntaxes
+ for string literals, so I don't have to escape the double quotes.
+
+ There will be a global called `search` which accepts a regex (it's implemented
+ using Rust's regex crate, so use that regex syntax) and runs that regex on the
+ contents of every file in the code base (aside from gitignored files), then
+ returns an array of tables with two fields: "path" (the path to the file that
+ had the matches) and "matches" (an array of strings, with each string being a
+ match that was found within the file).
@@ -4,7 +4,7 @@ use futures::{
channel::{mpsc, oneshot},
pin_mut, SinkExt, StreamExt,
};
-use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
+use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
use parking_lot::Mutex;
use project::{search::SearchQuery, Fs, Project};
@@ -16,8 +16,6 @@ use std::{
};
use util::{paths::PathMatcher, ResultExt};
-use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
-
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
pub struct ScriptSession {
@@ -45,50 +43,41 @@ impl ScriptSession {
}
}
- pub fn new_script(&mut self) -> ScriptId {
- let id = ScriptId(self.scripts.len() as u32);
- let script = Script {
- id,
- state: ScriptState::Generating,
- source: SharedString::new_static(""),
- };
- self.scripts.push(script);
- id
- }
-
pub fn run_script(
&mut self,
- script_id: ScriptId,
script_src: String,
cx: &mut Context<Self>,
- ) -> Task<anyhow::Result<()>> {
- let script = self.get_mut(script_id);
+ ) -> (ScriptId, Task<()>) {
+ let id = ScriptId(self.scripts.len() as u32);
let stdout = Arc::new(Mutex::new(String::new()));
- script.source = script_src.clone().into();
- script.state = ScriptState::Running {
- stdout: stdout.clone(),
+
+ let script = Script {
+ state: ScriptState::Running {
+ stdout: stdout.clone(),
+ },
};
+ self.scripts.push(script);
let task = self.run_lua(script_src, stdout, cx);
- cx.emit(ScriptEvent::Spawned(script_id));
-
- cx.spawn(|session, mut cx| async move {
+ let task = cx.spawn(|session, mut cx| async move {
let result = task.await;
- session.update(&mut cx, |session, cx| {
- let script = session.get_mut(script_id);
- let stdout = script.stdout_snapshot();
+ session
+ .update(&mut cx, |session, _cx| {
+ let script = session.get_mut(id);
+ let stdout = script.stdout_snapshot();
- script.state = match result {
- Ok(()) => ScriptState::Succeeded { stdout },
- Err(error) => ScriptState::Failed { stdout, error },
- };
+ script.state = match result {
+ Ok(()) => ScriptState::Succeeded { stdout },
+ Err(error) => ScriptState::Failed { stdout, error },
+ };
+ })
+ .log_err();
+ });
- cx.emit(ScriptEvent::Exited(script_id))
- })
- })
+ (id, task)
}
fn run_lua(
@@ -808,25 +797,14 @@ impl UserData for FileContent {
}
}
-#[derive(Debug)]
-pub enum ScriptEvent {
- Spawned(ScriptId),
- Exited(ScriptId),
-}
-
-impl EventEmitter<ScriptEvent> for ScriptSession {}
-
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
pub struct Script {
- pub id: ScriptId,
pub state: ScriptState,
- pub source: SharedString,
}
pub enum ScriptState {
- Generating,
Running {
stdout: Arc<Mutex<String>>,
},
@@ -840,14 +818,9 @@ pub enum ScriptState {
}
impl Script {
- pub fn source_tag(&self) -> String {
- format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
- }
-
/// If exited, returns a message with the output for the LLM
pub fn output_message_for_llm(&self) -> Option<String> {
match &self.state {
- ScriptState::Generating { .. } => None,
ScriptState::Running { .. } => None,
ScriptState::Succeeded { stdout } => {
format!("Here's the script output:\n{}", stdout).into()
@@ -863,22 +836,11 @@ impl Script {
/// Get a snapshot of the script's stdout
pub fn stdout_snapshot(&self) -> String {
match &self.state {
- ScriptState::Generating { .. } => String::new(),
ScriptState::Running { stdout } => stdout.lock().clone(),
ScriptState::Succeeded { stdout } => stdout.clone(),
ScriptState::Failed { stdout, .. } => stdout.clone(),
}
}
-
- /// Returns the error if the script failed, otherwise None
- pub fn error(&self) -> Option<&anyhow::Error> {
- match &self.state {
- ScriptState::Generating { .. } => None,
- ScriptState::Running { .. } => None,
- ScriptState::Succeeded { .. } => None,
- ScriptState::Failed { error, .. } => Some(error),
- }
- }
}
#[cfg(test)]
@@ -933,14 +895,10 @@ mod tests {
let project = Project::test(fs, [Path::new("/")], cx).await;
let session = cx.new(|cx| ScriptSession::new(project, cx));
- let (script_id, task) = session.update(cx, |session, cx| {
- let script_id = session.new_script();
- let task = session.run_script(script_id, source.to_string(), cx);
-
- (script_id, task)
- });
+ let (script_id, task) =
+ session.update(cx, |session, cx| session.run_script(source.to_string(), cx));
- task.await?;
+ task.await;
Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot()))
}
@@ -98,6 +98,7 @@ remote.workspace = true
repl.workspace = true
reqwest_client.workspace = true
rope.workspace = true
+scripting_tool.workspace = true
search.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -476,6 +476,7 @@ fn main() {
cx,
);
assistant_tools::init(cx);
+ scripting_tool::init(cx);
repl::init(app_state.fs.clone(), cx);
extension_host::init(
extension_host_proxy,