Add lua script access to code using `cx` + reuse project search logic (#26269)

Michael Sloan created

Access to `cx` will be needed for anything that queries entities. In
this commit, that means calling `WorktreeStore::find_search_candidates`.
In the future it will also cover things like access to LSP / tree-sitter
outlines / etc.

Changes to support access to `cx` from functions provided to the Lua
script:

* Adds a channel of requests that require a `cx`. Work enqueued to this
channel is run on the foreground thread.

* Adds `async` and `send` features to `mlua` crate so that async Rust
functions can be used from Lua.

* Changes uses of `Rc<RefCell<...>>` to `Arc<Mutex<...>>` so that the
futures are `Send`.

One benefit of reusing project search logic for search candidates is
that it properly skips ignored paths, rather than relying on a
hard-coded list of directory names.

Release Notes:

- N/A

Change summary

Cargo.lock                                  |   6 
Cargo.toml                                  |   2 
crates/scripting_tool/Cargo.toml            |   7 
crates/scripting_tool/src/scripting_tool.rs | 314 +++++++++++++++-------
4 files changed, 223 insertions(+), 106 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -8134,6 +8134,7 @@ checksum = "d3f763c1041eff92ffb5d7169968a327e1ed2ebfe425dac0ee5a35f29082534b"
 dependencies = [
  "bstr",
  "either",
+ "futures-util",
  "mlua-sys",
  "num-traits",
  "parking_lot",
@@ -11915,12 +11916,17 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assistant_tool",
+ "futures 0.3.31",
  "gpui",
  "mlua",
+ "parking_lot",
+ "project",
  "regex",
  "schemars",
  "serde",
  "serde_json",
+ "smol",
+ "util",
  "workspace",
 ]
 

Cargo.toml 🔗

@@ -452,7 +452,7 @@ livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "
 ], default-features = false }
 log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
 markup5ever_rcdom = "0.3.0"
-mlua = { version = "0.10", features = ["lua54", "vendored"] }
+mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
 nanoid = "0.4"
 nbformat = { version = "0.10.0" }
 nix = "0.29"

crates/scripting_tool/Cargo.toml 🔗

@@ -15,10 +15,15 @@ doctest = false
 [dependencies]
 anyhow.workspace = true
 assistant_tool.workspace = true
+futures.workspace = true
 gpui.workspace = true
 mlua.workspace = true
+parking_lot.workspace = true
+project.workspace = true
+regex.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+smol.workspace = true
+util.workspace = true
 workspace.workspace = true
-regex.workspace = true

crates/scripting_tool/src/scripting_tool.rs 🔗

@@ -1,16 +1,22 @@
 use anyhow::anyhow;
 use assistant_tool::{Tool, ToolRegistry};
-use gpui::{App, AppContext as _, Task, WeakEntity, Window};
+use futures::{
+    channel::{mpsc, oneshot},
+    SinkExt, StreamExt as _,
+};
+use gpui::{App, AppContext as _, AsyncApp, Task, WeakEntity, Window};
 use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
+use parking_lot::Mutex;
+use project::{search::SearchQuery, ProjectPath, WorktreeId};
 use schemars::JsonSchema;
 use serde::Deserialize;
 use std::{
     cell::RefCell,
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     path::{Path, PathBuf},
-    rc::Rc,
     sync::Arc,
 };
+use util::paths::PathMatcher;
 use workspace::Workspace;
 
 pub fn init(cx: &App) {
@@ -59,32 +65,49 @@ string being a match that was found within the file)."#.into()
         _window: &mut Window,
         cx: &mut App,
     ) -> Task<anyhow::Result<String>> {
-        let root_dir = workspace.update(cx, |workspace, cx| {
+        let worktree_root_dir_and_id = workspace.update(cx, |workspace, cx| {
             let first_worktree = workspace
                 .visible_worktrees(cx)
                 .next()
                 .ok_or_else(|| anyhow!("no worktrees"))?;
-            workspace
-                .absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
-                .ok_or_else(|| anyhow!("no worktree root"))
+            let worktree_id = first_worktree.read(cx).id();
+            let root_dir = workspace
+                .absolute_path_of_worktree(worktree_id, cx)
+                .ok_or_else(|| anyhow!("no worktree root"))?;
+            Ok((root_dir, worktree_id))
         });
-        let root_dir = match root_dir {
-            Ok(root_dir) => root_dir,
-            Err(err) => return Task::ready(Err(err)),
-        };
-        let root_dir = match root_dir {
-            Ok(root_dir) => root_dir,
+        let (root_dir, worktree_id) = match worktree_root_dir_and_id {
+            Ok(Ok(worktree_root_dir_and_id)) => worktree_root_dir_and_id,
+            Ok(Err(err)) => return Task::ready(Err(err)),
             Err(err) => return Task::ready(Err(err)),
         };
         let input = match serde_json::from_value::<ScriptingToolInput>(input) {
             Err(err) => return Task::ready(Err(err.into())),
             Ok(input) => input,
         };
+
+        let (foreground_tx, mut foreground_rx) = mpsc::channel::<ForegroundFn>(1);
+
+        cx.spawn(move |cx| async move {
+            while let Some(request) = foreground_rx.next().await {
+                request.0(cx.clone());
+            }
+        })
+        .detach();
+
         let lua_script = input.lua_script;
         cx.background_spawn(async move {
             let fs_changes = HashMap::new();
-            let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
-                .map_err(|err| anyhow!(format!("{err}")))?;
+            let output = run_sandboxed_lua(
+                &lua_script,
+                fs_changes,
+                root_dir,
+                worktree_id,
+                workspace,
+                foreground_tx,
+            )
+            .await
+            .map_err(|err| anyhow!(format!("{err}")))?;
             let output = output.printed_lines.join("\n");
 
             Ok(format!("The script output the following:\n{output}"))
@@ -92,6 +115,38 @@ string being a match that was found within the file)."#.into()
     }
 }
 
+struct ForegroundFn(Box<dyn FnOnce(AsyncApp) + Send>);
+
+async fn run_foreground_fn<R: Send + 'static>(
+    description: &str,
+    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+    function: Box<dyn FnOnce(AsyncApp) -> anyhow::Result<R> + Send>,
+) -> Result<R> {
+    let (response_tx, response_rx) = oneshot::channel();
+    let send_result = foreground_tx
+        .send(ForegroundFn(Box::new(move |cx| {
+            response_tx.send(function(cx)).ok();
+        })))
+        .await;
+    match send_result {
+        Ok(()) => (),
+        Err(err) => {
+            return Err(mlua::Error::runtime(format!(
+                "Internal error while enqueuing work for {description}: {err}"
+            )))
+        }
+    }
+    match response_rx.await {
+        Ok(Ok(result)) => Ok(result),
+        Ok(Err(err)) => Err(mlua::Error::runtime(format!(
+            "Error while {description}: {err}"
+        ))),
+        Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
+            "Internal error: response oneshot was canceled while {description}."
+        ))),
+    }
+}
+
 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
 
 struct FileContent(RefCell<Vec<u8>>);
@@ -103,7 +158,7 @@ impl UserData for FileContent {
 }
 
 /// Sandboxed print() function in Lua.
-fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
+fn print(lua: &Lua, printed_lines: Arc<Mutex<Vec<String>>>) -> Result<Function> {
     lua.create_function(move |_, args: MultiValue| {
         let mut string = String::new();
 
@@ -117,7 +172,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
             string.push_str(arg.to_string()?.as_str())
         }
 
-        printed_lines.borrow_mut().push(string);
+        printed_lines.lock().push(string);
 
         Ok(())
     })
@@ -125,103 +180,139 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
 
 fn search(
     lua: &Lua,
-    _fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+    _fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
     root_dir: PathBuf,
+    worktree_id: WorktreeId,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: mpsc::Sender<ForegroundFn>,
 ) -> Result<Function> {
-    lua.create_function(move |lua, regex: String| {
-        use mlua::Table;
-        use regex::Regex;
-        use std::fs;
-
-        // Function to recursively search directory
-        let search_regex = match Regex::new(&regex) {
-            Ok(re) => re,
-            Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
-        };
+    lua.create_async_function(move |lua, regex: String| {
+        let root_dir = root_dir.clone();
+        let workspace = workspace.clone();
+        let mut foreground_tx = foreground_tx.clone();
+        async move {
+            use mlua::Table;
+            use regex::Regex;
+            use std::fs;
+
+            // TODO: Allow specification of these options.
+            let search_query = SearchQuery::regex(
+                &regex,
+                false,
+                false,
+                false,
+                PathMatcher::default(),
+                PathMatcher::default(),
+                None,
+            );
+            let search_query = match search_query {
+                Ok(query) => query,
+                Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
+            };
 
-        let mut search_results: Vec<Result<Table>> = Vec::new();
+            // TODO: Should use `search_query.regex`. The tool description should also be updated,
+            // as it specifies standard regex.
+            let search_regex = match Regex::new(&regex) {
+                Ok(re) => re,
+                Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
+            };
 
-        // Create an explicit stack for directories to process
-        let mut dir_stack = vec![root_dir.clone()];
+            let project_path_rx =
+                find_search_candidates(search_query, workspace, &mut foreground_tx).await?;
 
-        while let Some(current_dir) = dir_stack.pop() {
-            // Process each entry in the current directory
-            let entries = match fs::read_dir(&current_dir) {
-                Ok(entries) => entries,
-                Err(e) => return Err(e.into()),
-            };
+            let mut search_results: Vec<Result<Table>> = Vec::new();
+            while let Ok(project_path) = project_path_rx.recv().await {
+                if project_path.worktree_id != worktree_id {
+                    continue;
+                }
 
-            for entry_result in entries {
-                let entry = match entry_result {
-                    Ok(e) => e,
-                    Err(e) => return Err(e.into()),
-                };
-
-                let path = entry.path();
-
-                if path.is_dir() {
-                    // Skip .git directory and other common directories to ignore
-                    let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
-                    if !dir_name.starts_with('.')
-                        && dir_name != "node_modules"
-                        && dir_name != "target"
-                    {
-                        // Instead of recursive call, add to stack
-                        dir_stack.push(path);
-                    }
-                } else if path.is_file() {
-                    // Skip binary files and very large files
-                    if let Ok(metadata) = fs::metadata(&path) {
-                        if metadata.len() > 1_000_000 {
-                            // Skip files larger than 1MB
-                            continue;
-                        }
+                let path = root_dir.join(project_path.path);
+
+                // Skip files larger than 1MB
+                if let Ok(metadata) = fs::metadata(&path) {
+                    if metadata.len() > 1_000_000 {
+                        continue;
                     }
+                }
 
-                    // Attempt to read the file as text
-                    if let Ok(content) = fs::read_to_string(&path) {
-                        let mut matches = Vec::new();
+                // Attempt to read the file as text
+                if let Ok(content) = fs::read_to_string(&path) {
+                    let mut matches = Vec::new();
 
-                        // Find all regex matches in the content
-                        for capture in search_regex.find_iter(&content) {
-                            matches.push(capture.as_str().to_string());
-                        }
-
-                        // If we found matches, create a result entry
-                        if !matches.is_empty() {
-                            let result_entry = lua.create_table()?;
-                            result_entry.set("path", path.to_string_lossy().to_string())?;
+                    // Find all regex matches in the content
+                    for capture in search_regex.find_iter(&content) {
+                        matches.push(capture.as_str().to_string());
+                    }
 
-                            let matches_table = lua.create_table()?;
-                            for (i, m) in matches.iter().enumerate() {
-                                matches_table.set(i + 1, m.clone())?;
-                            }
-                            result_entry.set("matches", matches_table)?;
+                    // If we found matches, create a result entry
+                    if !matches.is_empty() {
+                        let result_entry = lua.create_table()?;
+                        result_entry.set("path", path.to_string_lossy().to_string())?;
 
-                            search_results.push(Ok(result_entry));
+                        let matches_table = lua.create_table()?;
+                        for (i, m) in matches.iter().enumerate() {
+                            matches_table.set(i + 1, m.clone())?;
                         }
+                        result_entry.set("matches", matches_table)?;
+
+                        search_results.push(Ok(result_entry));
                     }
                 }
             }
-        }
 
-        // Create a table to hold our results
-        let results_table = lua.create_table()?;
-        for (i, result) in search_results.into_iter().enumerate() {
-            match result {
-                Ok(entry) => results_table.set(i + 1, entry)?,
-                Err(e) => return Err(e),
+            // Create a table to hold our results
+            let results_table = lua.create_table()?;
+            for (i, result) in search_results.into_iter().enumerate() {
+                match result {
+                    Ok(entry) => results_table.set(i + 1, entry)?,
+                    Err(e) => return Err(e),
+                }
             }
-        }
 
-        Ok(results_table)
+            Ok(results_table)
+        }
     })
 }
 
+async fn find_search_candidates(
+    search_query: SearchQuery,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+) -> Result<smol::channel::Receiver<ProjectPath>> {
+    run_foreground_fn(
+        "finding search file candidates",
+        foreground_tx,
+        Box::new(move |mut cx| {
+            workspace.update(&mut cx, move |workspace, cx| {
+                workspace.project().update(cx, |project, cx| {
+                    project.worktree_store().update(cx, |worktree_store, cx| {
+                        // TODO: Better limit? For now this is the same as
+                        // MAX_SEARCH_RESULT_FILES.
+                        let limit = 5000;
+                        // TODO: Providing non-empty open_entries can make this a bit more
+                        // efficient as it can skip checking that these paths are textual.
+                        let open_entries = HashSet::default();
+                        // TODO: This is searching all worktrees, but should only search the
+                        // current worktree
+                        worktree_store.find_search_candidates(
+                            search_query,
+                            limit,
+                            open_entries,
+                            project.fs().clone(),
+                            cx,
+                        )
+                    })
+                })
+            })
+        }),
+    )
+    .await
+}
+
 /// Sandboxed io.open() function in Lua.
 fn io_open(
     lua: &Lua,
-    fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
     root_dir: PathBuf,
 ) -> Result<Function> {
     lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
@@ -281,7 +372,7 @@ fn io_open(
                     // Don't actually write to disk; instead, just update fs_changes.
                     let path_buf = PathBuf::from(&path);
                     fs_changes
-                        .borrow_mut()
+                        .lock()
                         .insert(path_buf.clone(), content_vec.clone());
                 }
 
@@ -333,13 +424,15 @@ fn io_open(
             return Ok((Some(file), String::new()));
         }
 
-        let is_in_changes = fs_changes.borrow().contains_key(&path);
+        let fs_changes_map = fs_changes.lock();
+
+        let is_in_changes = fs_changes_map.contains_key(&path);
         let file_exists = is_in_changes || path.exists();
         let mut file_content = Vec::new();
 
         if file_exists && !truncate {
             if is_in_changes {
-                file_content = fs_changes.borrow().get(&path).unwrap().clone();
+                file_content = fs_changes_map.get(&path).unwrap().clone();
             } else {
                 // Try to read existing content if file exists and we're not truncating
                 match std::fs::read(&path) {
@@ -349,6 +442,8 @@ fn io_open(
             }
         }
 
+        drop(fs_changes_map); // Unlock the fs_changes mutex.
+
         // If in append mode, position should be at the end
         let position = if append && file_exists {
             file_content.len()
@@ -582,9 +677,7 @@ fn io_open(
                 // Update fs_changes
                 let path = file_userdata.get::<String>("__path")?;
                 let path_buf = PathBuf::from(path);
-                fs_changes
-                    .borrow_mut()
-                    .insert(path_buf, content_vec.clone());
+                fs_changes.lock().insert(path_buf, content_vec.clone());
 
                 Ok(true)
             })?
@@ -597,33 +690,46 @@ fn io_open(
 }
 
 /// Runs a Lua script in a sandboxed environment and returns the printed lines
-pub fn run_sandboxed_lua(
+async fn run_sandboxed_lua(
     script: &str,
     fs_changes: HashMap<PathBuf, Vec<u8>>,
     root_dir: PathBuf,
+    worktree_id: WorktreeId,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: mpsc::Sender<ForegroundFn>,
 ) -> Result<ScriptOutput> {
     let lua = Lua::new();
     lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
     let globals = lua.globals();
 
     // Track the lines the Lua script prints out.
-    let printed_lines = Rc::new(RefCell::new(Vec::new()));
-    let fs = Rc::new(RefCell::new(fs_changes));
+    let printed_lines = Arc::new(Mutex::new(Vec::new()));
+    let fs = Arc::new(Mutex::new(fs_changes));
 
     globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
-    globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
+    globals.set(
+        "search",
+        search(
+            &lua,
+            fs.clone(),
+            root_dir.clone(),
+            worktree_id,
+            workspace,
+            foreground_tx,
+        )?,
+    )?;
     globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
     globals.set("user_script", script)?;
 
-    lua.load(SANDBOX_PREAMBLE).exec()?;
+    lua.load(SANDBOX_PREAMBLE).exec_async().await?;
 
-    drop(lua); // Necessary so the Rc'd values get decremented.
+    drop(lua); // Decrements the Arcs so that try_unwrap works.
 
     Ok(ScriptOutput {
-        printed_lines: Rc::try_unwrap(printed_lines)
+        printed_lines: Arc::try_unwrap(printed_lines)
             .expect("There are still other references to printed_lines")
             .into_inner(),
-        fs_changes: Rc::try_unwrap(fs)
+        fs_changes: Arc::try_unwrap(fs)
             .expect("There are still other references to fs_changes")
             .into_inner(),
     })