@@ -1,16 +1,22 @@
use anyhow::anyhow;
use assistant_tool::{Tool, ToolRegistry};
-use gpui::{App, AppContext as _, Task, WeakEntity, Window};
+use futures::{
+ channel::{mpsc, oneshot},
+ SinkExt, StreamExt as _,
+};
+use gpui::{App, AppContext as _, AsyncApp, Task, WeakEntity, Window};
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
+use parking_lot::Mutex;
+use project::{search::SearchQuery, ProjectPath, WorktreeId};
use schemars::JsonSchema;
use serde::Deserialize;
use std::{
cell::RefCell,
- collections::HashMap,
+ collections::{HashMap, HashSet},
path::{Path, PathBuf},
- rc::Rc,
sync::Arc,
};
+use util::paths::PathMatcher;
use workspace::Workspace;
pub fn init(cx: &App) {
@@ -59,32 +65,49 @@ string being a match that was found within the file)."#.into()
_window: &mut Window,
cx: &mut App,
) -> Task<anyhow::Result<String>> {
- let root_dir = workspace.update(cx, |workspace, cx| {
+ let worktree_root_dir_and_id = workspace.update(cx, |workspace, cx| {
let first_worktree = workspace
.visible_worktrees(cx)
.next()
.ok_or_else(|| anyhow!("no worktrees"))?;
- workspace
- .absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
- .ok_or_else(|| anyhow!("no worktree root"))
+ let worktree_id = first_worktree.read(cx).id();
+ let root_dir = workspace
+ .absolute_path_of_worktree(worktree_id, cx)
+ .ok_or_else(|| anyhow!("no worktree root"))?;
+ Ok((root_dir, worktree_id))
});
- let root_dir = match root_dir {
- Ok(root_dir) => root_dir,
- Err(err) => return Task::ready(Err(err)),
- };
- let root_dir = match root_dir {
- Ok(root_dir) => root_dir,
+ let (root_dir, worktree_id) = match worktree_root_dir_and_id {
+ Ok(Ok(worktree_root_dir_and_id)) => worktree_root_dir_and_id,
+ Ok(Err(err)) => return Task::ready(Err(err)),
Err(err) => return Task::ready(Err(err)),
};
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
Err(err) => return Task::ready(Err(err.into())),
Ok(input) => input,
};
+
+ let (foreground_tx, mut foreground_rx) = mpsc::channel::<ForegroundFn>(1);
+
+ cx.spawn(move |cx| async move {
+ while let Some(request) = foreground_rx.next().await {
+ request.0(cx.clone());
+ }
+ })
+ .detach();
+
let lua_script = input.lua_script;
cx.background_spawn(async move {
let fs_changes = HashMap::new();
- let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
- .map_err(|err| anyhow!(format!("{err}")))?;
+ let output = run_sandboxed_lua(
+ &lua_script,
+ fs_changes,
+ root_dir,
+ worktree_id,
+ workspace,
+ foreground_tx,
+ )
+ .await
+ .map_err(|err| anyhow!(format!("{err}")))?;
let output = output.printed_lines.join("\n");
Ok(format!("The script output the following:\n{output}"))
@@ -92,6 +115,38 @@ string being a match that was found within the file)."#.into()
}
}
+struct ForegroundFn(Box<dyn FnOnce(AsyncApp) + Send>);
+
+async fn run_foreground_fn<R: Send + 'static>(
+ description: &str,
+ foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+ function: Box<dyn FnOnce(AsyncApp) -> anyhow::Result<R> + Send>,
+) -> Result<R> {
+ let (response_tx, response_rx) = oneshot::channel();
+ let send_result = foreground_tx
+ .send(ForegroundFn(Box::new(move |cx| {
+ response_tx.send(function(cx)).ok();
+ })))
+ .await;
+ match send_result {
+ Ok(()) => (),
+ Err(err) => {
+ return Err(mlua::Error::runtime(format!(
+ "Internal error while enqueuing work for {description}: {err}"
+ )))
+ }
+ }
+ match response_rx.await {
+ Ok(Ok(result)) => Ok(result),
+ Ok(Err(err)) => Err(mlua::Error::runtime(format!(
+ "Error while {description}: {err}"
+ ))),
+ Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
+ "Internal error: response oneshot was canceled while {description}."
+ ))),
+ }
+}
+
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
struct FileContent(RefCell<Vec<u8>>);
@@ -103,7 +158,7 @@ impl UserData for FileContent {
}
/// Sandboxed print() function in Lua.
-fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
+fn print(lua: &Lua, printed_lines: Arc<Mutex<Vec<String>>>) -> Result<Function> {
lua.create_function(move |_, args: MultiValue| {
let mut string = String::new();
@@ -117,7 +172,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
string.push_str(arg.to_string()?.as_str())
}
- printed_lines.borrow_mut().push(string);
+ printed_lines.lock().push(string);
Ok(())
})
@@ -125,103 +180,139 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
fn search(
lua: &Lua,
- _fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+ _fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf,
+ worktree_id: WorktreeId,
+ workspace: WeakEntity<Workspace>,
+ foreground_tx: mpsc::Sender<ForegroundFn>,
) -> Result<Function> {
- lua.create_function(move |lua, regex: String| {
- use mlua::Table;
- use regex::Regex;
- use std::fs;
-
- // Function to recursively search directory
- let search_regex = match Regex::new(®ex) {
- Ok(re) => re,
- Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
- };
+ lua.create_async_function(move |lua, regex: String| {
+ let root_dir = root_dir.clone();
+ let workspace = workspace.clone();
+ let mut foreground_tx = foreground_tx.clone();
+ async move {
+ use mlua::Table;
+ use regex::Regex;
+ use std::fs;
+
+ // TODO: Allow specification of these options.
+ let search_query = SearchQuery::regex(
+ ®ex,
+ false,
+ false,
+ false,
+ PathMatcher::default(),
+ PathMatcher::default(),
+ None,
+ );
+ let search_query = match search_query {
+ Ok(query) => query,
+ Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
+ };
- let mut search_results: Vec<Result<Table>> = Vec::new();
+ // TODO: Should use `search_query.regex`. The tool description should also be updated,
+ // as it specifies standard regex.
+ let search_regex = match Regex::new(®ex) {
+ Ok(re) => re,
+ Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
+ };
- // Create an explicit stack for directories to process
- let mut dir_stack = vec![root_dir.clone()];
+ let project_path_rx =
+ find_search_candidates(search_query, workspace, &mut foreground_tx).await?;
- while let Some(current_dir) = dir_stack.pop() {
- // Process each entry in the current directory
- let entries = match fs::read_dir(¤t_dir) {
- Ok(entries) => entries,
- Err(e) => return Err(e.into()),
- };
+ let mut search_results: Vec<Result<Table>> = Vec::new();
+ while let Ok(project_path) = project_path_rx.recv().await {
+ if project_path.worktree_id != worktree_id {
+ continue;
+ }
- for entry_result in entries {
- let entry = match entry_result {
- Ok(e) => e,
- Err(e) => return Err(e.into()),
- };
-
- let path = entry.path();
-
- if path.is_dir() {
- // Skip .git directory and other common directories to ignore
- let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
- if !dir_name.starts_with('.')
- && dir_name != "node_modules"
- && dir_name != "target"
- {
- // Instead of recursive call, add to stack
- dir_stack.push(path);
- }
- } else if path.is_file() {
- // Skip binary files and very large files
- if let Ok(metadata) = fs::metadata(&path) {
- if metadata.len() > 1_000_000 {
- // Skip files larger than 1MB
- continue;
- }
+ let path = root_dir.join(project_path.path);
+
+ // Skip files larger than 1MB
+ if let Ok(metadata) = fs::metadata(&path) {
+ if metadata.len() > 1_000_000 {
+ continue;
}
+ }
- // Attempt to read the file as text
- if let Ok(content) = fs::read_to_string(&path) {
- let mut matches = Vec::new();
+ // Attempt to read the file as text
+ if let Ok(content) = fs::read_to_string(&path) {
+ let mut matches = Vec::new();
- // Find all regex matches in the content
- for capture in search_regex.find_iter(&content) {
- matches.push(capture.as_str().to_string());
- }
-
- // If we found matches, create a result entry
- if !matches.is_empty() {
- let result_entry = lua.create_table()?;
- result_entry.set("path", path.to_string_lossy().to_string())?;
+ // Find all regex matches in the content
+ for capture in search_regex.find_iter(&content) {
+ matches.push(capture.as_str().to_string());
+ }
- let matches_table = lua.create_table()?;
- for (i, m) in matches.iter().enumerate() {
- matches_table.set(i + 1, m.clone())?;
- }
- result_entry.set("matches", matches_table)?;
+ // If we found matches, create a result entry
+ if !matches.is_empty() {
+ let result_entry = lua.create_table()?;
+ result_entry.set("path", path.to_string_lossy().to_string())?;
- search_results.push(Ok(result_entry));
+ let matches_table = lua.create_table()?;
+ for (i, m) in matches.iter().enumerate() {
+ matches_table.set(i + 1, m.clone())?;
}
+ result_entry.set("matches", matches_table)?;
+
+ search_results.push(Ok(result_entry));
}
}
}
- }
- // Create a table to hold our results
- let results_table = lua.create_table()?;
- for (i, result) in search_results.into_iter().enumerate() {
- match result {
- Ok(entry) => results_table.set(i + 1, entry)?,
- Err(e) => return Err(e),
+ // Create a table to hold our results
+ let results_table = lua.create_table()?;
+ for (i, result) in search_results.into_iter().enumerate() {
+ match result {
+ Ok(entry) => results_table.set(i + 1, entry)?,
+ Err(e) => return Err(e),
+ }
}
- }
- Ok(results_table)
+ Ok(results_table)
+ }
})
}
+async fn find_search_candidates(
+ search_query: SearchQuery,
+ workspace: WeakEntity<Workspace>,
+ foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+) -> Result<smol::channel::Receiver<ProjectPath>> {
+ run_foreground_fn(
+ "finding search file candidates",
+ foreground_tx,
+ Box::new(move |mut cx| {
+ workspace.update(&mut cx, move |workspace, cx| {
+ workspace.project().update(cx, |project, cx| {
+ project.worktree_store().update(cx, |worktree_store, cx| {
+ // TODO: Better limit? For now this is the same as
+ // MAX_SEARCH_RESULT_FILES.
+ let limit = 5000;
+ // TODO: Providing non-empty open_entries can make this a bit more
+ // efficient as it can skip checking that these paths are textual.
+ let open_entries = HashSet::default();
+ // TODO: This is searching all worktrees, but should only search the
+ // current worktree
+ worktree_store.find_search_candidates(
+ search_query,
+ limit,
+ open_entries,
+ project.fs().clone(),
+ cx,
+ )
+ })
+ })
+ })
+ }),
+ )
+ .await
+}
+
/// Sandboxed io.open() function in Lua.
fn io_open(
lua: &Lua,
- fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+ fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf,
) -> Result<Function> {
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
@@ -281,7 +372,7 @@ fn io_open(
// Don't actually write to disk; instead, just update fs_changes.
let path_buf = PathBuf::from(&path);
fs_changes
- .borrow_mut()
+ .lock()
.insert(path_buf.clone(), content_vec.clone());
}
@@ -333,13 +424,15 @@ fn io_open(
return Ok((Some(file), String::new()));
}
- let is_in_changes = fs_changes.borrow().contains_key(&path);
+ let fs_changes_map = fs_changes.lock();
+
+ let is_in_changes = fs_changes_map.contains_key(&path);
let file_exists = is_in_changes || path.exists();
let mut file_content = Vec::new();
if file_exists && !truncate {
if is_in_changes {
- file_content = fs_changes.borrow().get(&path).unwrap().clone();
+ file_content = fs_changes_map.get(&path).unwrap().clone();
} else {
// Try to read existing content if file exists and we're not truncating
match std::fs::read(&path) {
@@ -349,6 +442,8 @@ fn io_open(
}
}
+ drop(fs_changes_map); // Unlock the fs_changes mutex.
+
// If in append mode, position should be at the end
let position = if append && file_exists {
file_content.len()
@@ -582,9 +677,7 @@ fn io_open(
// Update fs_changes
let path = file_userdata.get::<String>("__path")?;
let path_buf = PathBuf::from(path);
- fs_changes
- .borrow_mut()
- .insert(path_buf, content_vec.clone());
+ fs_changes.lock().insert(path_buf, content_vec.clone());
Ok(true)
})?
@@ -597,33 +690,46 @@ fn io_open(
}
/// Runs a Lua script in a sandboxed environment and returns the printed lines
-pub fn run_sandboxed_lua(
+async fn run_sandboxed_lua(
script: &str,
fs_changes: HashMap<PathBuf, Vec<u8>>,
root_dir: PathBuf,
+ worktree_id: WorktreeId,
+ workspace: WeakEntity<Workspace>,
+ foreground_tx: mpsc::Sender<ForegroundFn>,
) -> Result<ScriptOutput> {
let lua = Lua::new();
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
let globals = lua.globals();
// Track the lines the Lua script prints out.
- let printed_lines = Rc::new(RefCell::new(Vec::new()));
- let fs = Rc::new(RefCell::new(fs_changes));
+ let printed_lines = Arc::new(Mutex::new(Vec::new()));
+ let fs = Arc::new(Mutex::new(fs_changes));
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
- globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
+ globals.set(
+ "search",
+ search(
+ &lua,
+ fs.clone(),
+ root_dir.clone(),
+ worktree_id,
+ workspace,
+ foreground_tx,
+ )?,
+ )?;
globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
globals.set("user_script", script)?;
- lua.load(SANDBOX_PREAMBLE).exec()?;
+ lua.load(SANDBOX_PREAMBLE).exec_async().await?;
- drop(lua); // Necessary so the Rc'd values get decremented.
+ drop(lua); // Decrements the Arcs so that try_unwrap works.
Ok(ScriptOutput {
- printed_lines: Rc::try_unwrap(printed_lines)
+ printed_lines: Arc::try_unwrap(printed_lines)
.expect("There are still other references to printed_lines")
.into_inner(),
- fs_changes: Rc::try_unwrap(fs)
+ fs_changes: Arc::try_unwrap(fs)
.expect("There are still other references to fs_changes")
.into_inner(),
})