1use anyhow::anyhow;
2use assistant_tool::{Tool, ToolRegistry};
3use gpui::{App, AppContext as _, Task, WeakEntity, Window};
4use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
5use schemars::JsonSchema;
6use serde::Deserialize;
7use std::{
8 cell::RefCell,
9 collections::HashMap,
10 path::{Path, PathBuf},
11 rc::Rc,
12 sync::Arc,
13};
14use workspace::Workspace;
15
16pub fn init(cx: &App) {
17 let registry = ToolRegistry::global(cx);
18 registry.register_tool(ScriptingTool);
19}
20
// Input payload for the scripting tool, deserialized from the JSON value handed to
// `Tool::run`. Its JSON schema is what `input_schema` reports.
// NOTE: plain `//` comments are used on purpose; `///` doc comments would be picked up
// by the `JsonSchema` derive and change the generated schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScriptingToolInput {
    // Source text of the Lua script to execute in the sandbox.
    lua_script: String,
}
25
/// Stateless assistant tool that runs a Lua script in a sandbox and reports what the
/// script printed.
struct ScriptingTool;
27
28impl Tool for ScriptingTool {
29 fn name(&self) -> String {
30 "lua-interpreter".into()
31 }
32
33 fn description(&self) -> String {
34 r#"You can write a Lua script and I'll run it on my code base and tell you what its output was,
35including both stdout as well as the git diff of changes it made to the filesystem. That way,
36you can get more information about the code base, or make changes to the code base directly.
37The lua script will have access to `io` and it will run with the current working directory being in
38the root of the code base, so you can use it to explore, search, make changes, etc. You can also have
39the script print things, and I'll tell you what the output was. Note that `io` only has `open`, and
40then the file it returns only has the methods read, write, and close - it doesn't have popen or
41anything else. Also, I'm going to be putting this Lua script into JSON, so please don't use Lua's
42double quote syntax for string literals - use one of Lua's other syntaxes for string literals, so I
43don't have to escape the double quotes. There will be a global called `search` which accepts a regex
44(it's implemented using Rust's regex crate, so use that regex syntax) and runs that regex on the contents
45of every file in the code base (aside from gitignored files), then returns an array of tables with two
46fields: "path" (the path to the file that had the matches) and "matches" (an array of strings, with each
47string being a match that was found within the file)."#.into()
48 }
49
50 fn input_schema(&self) -> serde_json::Value {
51 let schema = schemars::schema_for!(ScriptingToolInput);
52 serde_json::to_value(&schema).unwrap()
53 }
54
55 fn run(
56 self: Arc<Self>,
57 input: serde_json::Value,
58 workspace: WeakEntity<Workspace>,
59 _window: &mut Window,
60 cx: &mut App,
61 ) -> Task<anyhow::Result<String>> {
62 let root_dir = workspace.update(cx, |workspace, cx| {
63 let first_worktree = workspace
64 .visible_worktrees(cx)
65 .next()
66 .ok_or_else(|| anyhow!("no worktrees"))?;
67 workspace
68 .absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
69 .ok_or_else(|| anyhow!("no worktree root"))
70 });
71 let root_dir = match root_dir {
72 Ok(root_dir) => root_dir,
73 Err(err) => return Task::ready(Err(err)),
74 };
75 let root_dir = match root_dir {
76 Ok(root_dir) => root_dir,
77 Err(err) => return Task::ready(Err(err)),
78 };
79 let input = match serde_json::from_value::<ScriptingToolInput>(input) {
80 Err(err) => return Task::ready(Err(err.into())),
81 Ok(input) => input,
82 };
83 let lua_script = input.lua_script;
84 cx.background_spawn(async move {
85 let fs_changes = HashMap::new();
86 let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
87 .map_err(|err| anyhow!(format!("{err}")))?;
88 let output = output.printed_lines.join("\n");
89
90 Ok(format!("The script output the following:\n{output}"))
91 })
92 }
93}
94
/// Lua source embedded at compile time from `sandbox_preamble.lua`. `run_sandboxed_lua`
/// loads and executes it after installing the `sb_print`, `search`, `sb_io_open` and
/// `user_script` globals.
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
96
/// In-memory byte buffer of an opened sandbox file. It is stored as Lua userdata under
/// the file table's `__content` key, with the `RefCell` providing the mutability the
/// `write` method needs.
struct FileContent(RefCell<Vec<u8>>);

/// Lets `FileContent` be stored in Lua as opaque userdata; no methods are exposed to
/// scripts.
impl UserData for FileContent {
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}
104
105/// Sandboxed print() function in Lua.
106fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
107 lua.create_function(move |_, args: MultiValue| {
108 let mut string = String::new();
109
110 for arg in args.into_iter() {
111 // Lua's `print()` prints tab characters between each argument.
112 if !string.is_empty() {
113 string.push('\t');
114 }
115
116 // If the argument's to_string() fails, have the whole function call fail.
117 string.push_str(arg.to_string()?.as_str())
118 }
119
120 printed_lines.borrow_mut().push(string);
121
122 Ok(())
123 })
124}
125
126fn search(
127 lua: &Lua,
128 _fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
129 root_dir: PathBuf,
130) -> Result<Function> {
131 lua.create_function(move |lua, regex: String| {
132 use mlua::Table;
133 use regex::Regex;
134 use std::fs;
135
136 // Function to recursively search directory
137 let search_regex = match Regex::new(®ex) {
138 Ok(re) => re,
139 Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
140 };
141
142 let mut search_results: Vec<Result<Table>> = Vec::new();
143
144 // Create an explicit stack for directories to process
145 let mut dir_stack = vec![root_dir.clone()];
146
147 while let Some(current_dir) = dir_stack.pop() {
148 // Process each entry in the current directory
149 let entries = match fs::read_dir(¤t_dir) {
150 Ok(entries) => entries,
151 Err(e) => return Err(e.into()),
152 };
153
154 for entry_result in entries {
155 let entry = match entry_result {
156 Ok(e) => e,
157 Err(e) => return Err(e.into()),
158 };
159
160 let path = entry.path();
161
162 if path.is_dir() {
163 // Skip .git directory and other common directories to ignore
164 let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
165 if !dir_name.starts_with('.')
166 && dir_name != "node_modules"
167 && dir_name != "target"
168 {
169 // Instead of recursive call, add to stack
170 dir_stack.push(path);
171 }
172 } else if path.is_file() {
173 // Skip binary files and very large files
174 if let Ok(metadata) = fs::metadata(&path) {
175 if metadata.len() > 1_000_000 {
176 // Skip files larger than 1MB
177 continue;
178 }
179 }
180
181 // Attempt to read the file as text
182 if let Ok(content) = fs::read_to_string(&path) {
183 let mut matches = Vec::new();
184
185 // Find all regex matches in the content
186 for capture in search_regex.find_iter(&content) {
187 matches.push(capture.as_str().to_string());
188 }
189
190 // If we found matches, create a result entry
191 if !matches.is_empty() {
192 let result_entry = lua.create_table()?;
193 result_entry.set("path", path.to_string_lossy().to_string())?;
194
195 let matches_table = lua.create_table()?;
196 for (i, m) in matches.iter().enumerate() {
197 matches_table.set(i + 1, m.clone())?;
198 }
199 result_entry.set("matches", matches_table)?;
200
201 search_results.push(Ok(result_entry));
202 }
203 }
204 }
205 }
206 }
207
208 // Create a table to hold our results
209 let results_table = lua.create_table()?;
210 for (i, result) in search_results.into_iter().enumerate() {
211 match result {
212 Ok(entry) => results_table.set(i + 1, entry)?,
213 Err(e) => return Err(e),
214 }
215 }
216
217 Ok(results_table)
218 })
219}
220
221/// Sandboxed io.open() function in Lua.
222fn io_open(
223 lua: &Lua,
224 fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
225 root_dir: PathBuf,
226) -> Result<Function> {
227 lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
228 let mode = mode.unwrap_or_else(|| "r".to_string());
229
230 // Parse the mode string to determine read/write permissions
231 let read_perm = mode.contains('r');
232 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
233 let append = mode.contains('a');
234 let truncate = mode.contains('w');
235
236 // This will be the Lua value returned from the `open` function.
237 let file = lua.create_table()?;
238
239 // Store file metadata in the file
240 file.set("__path", path_str.clone())?;
241 file.set("__mode", mode.clone())?;
242 file.set("__read_perm", read_perm)?;
243 file.set("__write_perm", write_perm)?;
244
245 // Sandbox the path; it must be within root_dir
246 let path: PathBuf = {
247 let rust_path = Path::new(&path_str);
248
249 // Get absolute path
250 if rust_path.is_absolute() {
251 // Check if path starts with root_dir prefix without resolving symlinks
252 if !rust_path.starts_with(&root_dir) {
253 return Ok((
254 None,
255 format!(
256 "Error: Absolute path {} is outside the current working directory",
257 path_str
258 ),
259 ));
260 }
261 rust_path.to_path_buf()
262 } else {
263 // Make relative path absolute relative to cwd
264 root_dir.join(rust_path)
265 }
266 };
267
268 // close method
269 let close_fn = {
270 let fs_changes = fs_changes.clone();
271 lua.create_function(move |_lua, file_userdata: mlua::Table| {
272 let write_perm = file_userdata.get::<bool>("__write_perm")?;
273 let path = file_userdata.get::<String>("__path")?;
274
275 if write_perm {
276 // When closing a writable file, record the content
277 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
278 let content_ref = content.borrow::<FileContent>()?;
279 let content_vec = content_ref.0.borrow();
280
281 // Don't actually write to disk; instead, just update fs_changes.
282 let path_buf = PathBuf::from(&path);
283 fs_changes
284 .borrow_mut()
285 .insert(path_buf.clone(), content_vec.clone());
286 }
287
288 Ok(true)
289 })?
290 };
291 file.set("close", close_fn)?;
292
293 // If it's a directory, give it a custom read() and return early.
294 if path.is_dir() {
295 // TODO handle the case where we changed it in the in-memory fs
296
297 // Create a special directory handle
298 file.set("__is_directory", true)?;
299
300 // Store directory entries
301 let entries = match std::fs::read_dir(&path) {
302 Ok(entries) => {
303 let mut entry_names = Vec::new();
304 for entry in entries.flatten() {
305 entry_names.push(entry.file_name().to_string_lossy().into_owned());
306 }
307 entry_names
308 }
309 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
310 };
311
312 // Save the list of entries
313 file.set("__dir_entries", entries)?;
314 file.set("__dir_position", 0usize)?;
315
316 // Create a directory-specific read function
317 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
318 let position = file_userdata.get::<usize>("__dir_position")?;
319 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
320
321 if position >= entries.len() {
322 return Ok(None); // No more entries
323 }
324
325 let entry = entries[position].clone();
326 file_userdata.set("__dir_position", position + 1)?;
327
328 Ok(Some(entry))
329 })?;
330 file.set("read", read_fn)?;
331
332 // If we got this far, the directory was opened successfully
333 return Ok((Some(file), String::new()));
334 }
335
336 let is_in_changes = fs_changes.borrow().contains_key(&path);
337 let file_exists = is_in_changes || path.exists();
338 let mut file_content = Vec::new();
339
340 if file_exists && !truncate {
341 if is_in_changes {
342 file_content = fs_changes.borrow().get(&path).unwrap().clone();
343 } else {
344 // Try to read existing content if file exists and we're not truncating
345 match std::fs::read(&path) {
346 Ok(content) => file_content = content,
347 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
348 }
349 }
350 }
351
352 // If in append mode, position should be at the end
353 let position = if append && file_exists {
354 file_content.len()
355 } else {
356 0
357 };
358 file.set("__position", position)?;
359 file.set(
360 "__content",
361 lua.create_userdata(FileContent(RefCell::new(file_content)))?,
362 )?;
363
364 // Create file methods
365
366 // read method
367 let read_fn = {
368 lua.create_function(
369 |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
370 let read_perm = file_userdata.get::<bool>("__read_perm")?;
371 if !read_perm {
372 return Err(mlua::Error::runtime("File not open for reading"));
373 }
374
375 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
376 let mut position = file_userdata.get::<usize>("__position")?;
377 let content_ref = content.borrow::<FileContent>()?;
378 let content_vec = content_ref.0.borrow();
379
380 if position >= content_vec.len() {
381 return Ok(None); // EOF
382 }
383
384 match format {
385 Some(mlua::Value::String(s)) => {
386 let lossy_string = s.to_string_lossy();
387 let format_str: &str = lossy_string.as_ref();
388
389 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
390 match &format_str[0..2] {
391 "*a" => {
392 // Read entire file from current position
393 let result = String::from_utf8_lossy(&content_vec[position..])
394 .to_string();
395 position = content_vec.len();
396 file_userdata.set("__position", position)?;
397 Ok(Some(result))
398 }
399 "*l" => {
400 // Read next line
401 let mut line = Vec::new();
402 let mut found_newline = false;
403
404 while position < content_vec.len() {
405 let byte = content_vec[position];
406 position += 1;
407
408 if byte == b'\n' {
409 found_newline = true;
410 break;
411 }
412
413 // Skip \r in \r\n sequence but add it if it's alone
414 if byte == b'\r' {
415 if position < content_vec.len()
416 && content_vec[position] == b'\n'
417 {
418 position += 1;
419 found_newline = true;
420 break;
421 }
422 }
423
424 line.push(byte);
425 }
426
427 file_userdata.set("__position", position)?;
428
429 if !found_newline
430 && line.is_empty()
431 && position >= content_vec.len()
432 {
433 return Ok(None); // EOF
434 }
435
436 let result = String::from_utf8_lossy(&line).to_string();
437 Ok(Some(result))
438 }
439 "*n" => {
440 // Try to parse as a number (number of bytes to read)
441 match format_str.parse::<usize>() {
442 Ok(n) => {
443 let end =
444 std::cmp::min(position + n, content_vec.len());
445 let bytes = &content_vec[position..end];
446 let result = String::from_utf8_lossy(bytes).to_string();
447 position = end;
448 file_userdata.set("__position", position)?;
449 Ok(Some(result))
450 }
451 Err(_) => Err(mlua::Error::runtime(format!(
452 "Invalid format: {}",
453 format_str
454 ))),
455 }
456 }
457 "*L" => {
458 // Read next line keeping the end of line
459 let mut line = Vec::new();
460
461 while position < content_vec.len() {
462 let byte = content_vec[position];
463 position += 1;
464
465 line.push(byte);
466
467 if byte == b'\n' {
468 break;
469 }
470
471 // If we encounter a \r, add it and check if the next is \n
472 if byte == b'\r'
473 && position < content_vec.len()
474 && content_vec[position] == b'\n'
475 {
476 line.push(content_vec[position]);
477 position += 1;
478 break;
479 }
480 }
481
482 file_userdata.set("__position", position)?;
483
484 if line.is_empty() && position >= content_vec.len() {
485 return Ok(None); // EOF
486 }
487
488 let result = String::from_utf8_lossy(&line).to_string();
489 Ok(Some(result))
490 }
491 _ => Err(mlua::Error::runtime(format!(
492 "Unsupported format: {}",
493 format_str
494 ))),
495 }
496 }
497 Some(mlua::Value::Number(n)) => {
498 // Read n bytes
499 let n = n as usize;
500 let end = std::cmp::min(position + n, content_vec.len());
501 let bytes = &content_vec[position..end];
502 let result = String::from_utf8_lossy(bytes).to_string();
503 position = end;
504 file_userdata.set("__position", position)?;
505 Ok(Some(result))
506 }
507 Some(_) => Err(mlua::Error::runtime("Invalid format")),
508 None => {
509 // Default is to read a line
510 let mut line = Vec::new();
511 let mut found_newline = false;
512
513 while position < content_vec.len() {
514 let byte = content_vec[position];
515 position += 1;
516
517 if byte == b'\n' {
518 found_newline = true;
519 break;
520 }
521
522 // Handle \r\n
523 if byte == b'\r' {
524 if position < content_vec.len()
525 && content_vec[position] == b'\n'
526 {
527 position += 1;
528 found_newline = true;
529 break;
530 }
531 }
532
533 line.push(byte);
534 }
535
536 file_userdata.set("__position", position)?;
537
538 if !found_newline && line.is_empty() && position >= content_vec.len() {
539 return Ok(None); // EOF
540 }
541
542 let result = String::from_utf8_lossy(&line).to_string();
543 Ok(Some(result))
544 }
545 }
546 },
547 )?
548 };
549 file.set("read", read_fn)?;
550
551 // write method
552 let write_fn = {
553 let fs_changes = fs_changes.clone();
554
555 lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
556 let write_perm = file_userdata.get::<bool>("__write_perm")?;
557 if !write_perm {
558 return Err(mlua::Error::runtime("File not open for writing"));
559 }
560
561 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
562 let position = file_userdata.get::<usize>("__position")?;
563 let content_ref = content.borrow::<FileContent>()?;
564 let mut content_vec = content_ref.0.borrow_mut();
565
566 let bytes = text.as_bytes();
567
568 // Ensure the vector has enough capacity
569 if position + bytes.len() > content_vec.len() {
570 content_vec.resize(position + bytes.len(), 0);
571 }
572
573 // Write the bytes
574 for (i, &byte) in bytes.iter().enumerate() {
575 content_vec[position + i] = byte;
576 }
577
578 // Update position
579 let new_position = position + bytes.len();
580 file_userdata.set("__position", new_position)?;
581
582 // Update fs_changes
583 let path = file_userdata.get::<String>("__path")?;
584 let path_buf = PathBuf::from(path);
585 fs_changes
586 .borrow_mut()
587 .insert(path_buf, content_vec.clone());
588
589 Ok(true)
590 })?
591 };
592 file.set("write", write_fn)?;
593
594 // If we got this far, the file was opened successfully
595 Ok((Some(file), String::new()))
596 })
597}
598
599/// Runs a Lua script in a sandboxed environment and returns the printed lines
600pub fn run_sandboxed_lua(
601 script: &str,
602 fs_changes: HashMap<PathBuf, Vec<u8>>,
603 root_dir: PathBuf,
604) -> Result<ScriptOutput> {
605 let lua = Lua::new();
606 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
607 let globals = lua.globals();
608
609 // Track the lines the Lua script prints out.
610 let printed_lines = Rc::new(RefCell::new(Vec::new()));
611 let fs = Rc::new(RefCell::new(fs_changes));
612
613 globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
614 globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
615 globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
616 globals.set("user_script", script)?;
617
618 lua.load(SANDBOX_PREAMBLE).exec()?;
619
620 drop(lua); // Necessary so the Rc'd values get decremented.
621
622 Ok(ScriptOutput {
623 printed_lines: Rc::try_unwrap(printed_lines)
624 .expect("There are still other references to printed_lines")
625 .into_inner(),
626 fs_changes: Rc::try_unwrap(fs)
627 .expect("There are still other references to fs_changes")
628 .into_inner(),
629 })
630}
631
/// Result of running a sandboxed Lua script.
pub struct ScriptOutput {
    /// One entry per `print` call made by the script, in call order.
    printed_lines: Vec<String>,
    /// Files the script wrote, mapped to their new content. Nothing in here has been
    /// written to disk.
    #[allow(dead_code)]
    fs_changes: HashMap<PathBuf, Vec<u8>>,
}

#[allow(dead_code)]
impl ScriptOutput {
    /// How many lines ahead (in both the old and the new file) the simplified diff
    /// looks for the next common line.
    const DIFF_LOOKAHEAD: usize = 10;

    /// Builds a git-style diff for every changed file, keyed by path.
    ///
    /// Each path is compared against the file currently on disk; a path that does not
    /// exist on disk is reported as a new file. If the on-disk file exists but cannot
    /// be read, the entry holds an error message instead of a diff. Content is decoded
    /// lossily as UTF-8.
    fn fs_diff(&self) -> HashMap<PathBuf, String> {
        self.fs_changes
            .iter()
            .map(|(path, content)| {
                let new_content = String::from_utf8_lossy(content);
                let path_str = path.to_string_lossy();

                let diff = if path.exists() {
                    match std::fs::read(path) {
                        Ok(current_content) => Self::modified_file_diff(
                            &path_str,
                            &String::from_utf8_lossy(&current_content),
                            &new_content,
                        ),
                        Err(_) => format!("Error reading current file: {}", path.display()),
                    }
                } else {
                    Self::new_file_diff(&path_str, &new_content)
                };

                (path.clone(), diff)
            })
            .collect()
    }

    /// Diff of an existing file: header plus one hunk per run of differing lines.
    ///
    /// Very basic algorithm - not a minimal diff. On a mismatch it looks up to
    /// `DIFF_LOOKAHEAD` lines ahead on both sides for the next common line; everything
    /// before that line is emitted as removed/added. If no common line is found, all
    /// remaining lines of both files are emitted.
    fn modified_file_diff(path_str: &str, old_content: &str, new_content: &str) -> String {
        let old_lines: Vec<&str> = old_content.lines().collect();
        let new_lines: Vec<&str> = new_content.lines().collect();

        let mut diff = format!("diff --git a/{} b/{}\n", path_str, path_str);
        diff.push_str(&format!("--- a/{}\n", path_str));
        diff.push_str(&format!("+++ b/{}\n", path_str));

        let mut i = 0;
        let mut j = 0;

        while i < old_lines.len() || j < new_lines.len() {
            if i < old_lines.len() && j < new_lines.len() && old_lines[i] == new_lines[j] {
                i += 1;
                j += 1;
                continue;
            }

            // Find the next matching pair of lines within the lookahead window.
            let old_window = i..old_lines.len().min(i + Self::DIFF_LOOKAHEAD);
            let new_window = j..new_lines.len().min(j + Self::DIFF_LOOKAHEAD);
            let resync = old_window.into_iter().find_map(|look_i| {
                new_window
                    .clone()
                    .find(|&look_j| old_lines[look_i] == new_lines[look_j])
                    .map(|look_j| (look_i, look_j))
            });

            // Without a resync point the hunk covers the rest of both files.
            let (next_i, next_j) = resync.unwrap_or((old_lines.len(), new_lines.len()));

            // Output the hunk header
            diff.push_str(&format!(
                "@@ -{},{} +{},{} @@\n",
                i + 1,
                next_i - i,
                j + 1,
                next_j - j
            ));

            // Output removed lines
            for line in &old_lines[i..next_i] {
                diff.push_str(&format!("-{}\n", line));
            }

            // Output added lines
            for line in &new_lines[j..next_j] {
                diff.push_str(&format!("+{}\n", line));
            }

            if resync.is_none() {
                break;
            }

            // Step over the common line that ended the hunk.
            i = next_i + 1;
            j = next_j + 1;
        }

        diff
    }

    /// Diff of a file that does not exist on disk yet: every line is an addition.
    fn new_file_diff(path_str: &str, content: &str) -> String {
        let mut diff = format!("diff --git a/{} b/{}\n", path_str, path_str);
        diff.push_str("new file mode 100644\n");
        diff.push_str("--- /dev/null\n");
        diff.push_str(&format!("+++ b/{}\n", path_str));

        let lines: Vec<&str> = content.lines().collect();
        diff.push_str(&format!("@@ -0,0 +1,{} @@\n", lines.len()));

        for line in lines {
            diff.push_str(&format!("+{}\n", line));
        }

        diff
    }

    /// Concatenates the per-file diffs, sorted by path and separated by a newline.
    /// Returns `"No changes to files"` when there are no changes.
    fn diff_to_string(&self) -> String {
        let diff_map = self.fs_diff();

        if diff_map.is_empty() {
            return "No changes to files".to_string();
        }

        // Sort the paths for consistent output
        let mut paths: Vec<&PathBuf> = diff_map.keys().collect();
        paths.sort();

        let mut answer = String::new();
        for path in paths {
            if !answer.is_empty() {
                answer.push('\n');
            }
            answer.push_str(&diff_map[path]);
        }

        answer
    }
}