1use anyhow::Result;
2use collections::{HashMap, HashSet};
3use futures::{
4 channel::{mpsc, oneshot},
5 pin_mut, SinkExt, StreamExt,
6};
7use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
8use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods};
9use parking_lot::Mutex;
10use project::{search::SearchQuery, Fs, Project};
11use regex::Regex;
12use std::{
13 cell::RefCell,
14 path::{Path, PathBuf},
15 sync::Arc,
16};
17use util::{paths::PathMatcher, ResultExt};
18
/// Output captured while running a script with [`Session::run_script`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScriptOutput {
    /// Everything the script printed. Each `print` call contributes its arguments
    /// joined by tab characters, followed by a newline.
    pub stdout: String,
}
22
23struct ForegroundFn(Box<dyn FnOnce(WeakEntity<Session>, AsyncApp) + Send>);
24
/// State shared by the Lua scripts run through [`Session::run_script`].
pub struct Session {
    /// Project whose first visible worktree is the root directory that the sandboxed
    /// `io.open` is confined to, and whose worktree store backs the `search` function.
    project: Entity<Project>,
    // TODO Remove this
    /// In-memory record of file contents written by scripts, keyed by path. Writes are
    /// recorded here instead of being applied to disk.
    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
    /// Sending side of the queue of `ForegroundFn`s drained by `_invoke_foreground_fns`.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task spawned in `Session::new` that runs every queued `ForegroundFn`; stored so it
    /// lives as long as the session.
    /// NOTE(review): assumes dropping a gpui `Task` cancels it — confirm.
    _invoke_foreground_fns: Task<()>,
}
32
33impl Session {
34 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
35 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
36 Session {
37 project,
38 fs_changes: Arc::new(Mutex::new(HashMap::default())),
39 foreground_fns_tx,
40 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
41 while let Some(foreground_fn) = foreground_fns_rx.next().await {
42 foreground_fn.0(this.clone(), cx.clone());
43 }
44 }),
45 }
46 }
47
48 /// Runs a Lua script in a sandboxed environment and returns the printed lines
49 pub fn run_script(
50 &mut self,
51 script: String,
52 cx: &mut Context<Self>,
53 ) -> Task<Result<ScriptOutput>> {
54 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
55
56 // TODO Remove fs_changes
57 let fs_changes = self.fs_changes.clone();
58 // TODO Honor all worktrees instead of the first one
59 let root_dir = self
60 .project
61 .read(cx)
62 .visible_worktrees(cx)
63 .next()
64 .map(|worktree| worktree.read(cx).abs_path());
65 let fs = self.project.read(cx).fs().clone();
66 let foreground_fns_tx = self.foreground_fns_tx.clone();
67 cx.background_spawn(async move {
68 let lua = Lua::new();
69 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
70 let globals = lua.globals();
71 let stdout = Arc::new(Mutex::new(String::new()));
72 globals.set(
73 "sb_print",
74 lua.create_function({
75 let stdout = stdout.clone();
76 move |_, args: MultiValue| Self::print(args, &stdout)
77 })?,
78 )?;
79 globals.set(
80 "search",
81 lua.create_async_function({
82 let foreground_fns_tx = foreground_fns_tx.clone();
83 let fs = fs.clone();
84 move |lua, regex| {
85 Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex)
86 }
87 })?,
88 )?;
89 globals.set(
90 "sb_io_open",
91 lua.create_function({
92 let fs_changes = fs_changes.clone();
93 let root_dir = root_dir.clone();
94 move |lua, (path_str, mode)| {
95 Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
96 }
97 })?,
98 )?;
99 globals.set("user_script", script)?;
100
101 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
102
103 // Drop Lua instance to decrement reference count.
104 drop(lua);
105
106 let stdout = Arc::try_unwrap(stdout)
107 .expect("no more references to stdout")
108 .into_inner();
109 Ok(ScriptOutput { stdout })
110 })
111 }
112
113 /// Sandboxed print() function in Lua.
114 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
115 for (index, arg) in args.into_iter().enumerate() {
116 // Lua's `print()` prints tab characters between each argument.
117 if index > 0 {
118 stdout.lock().push('\t');
119 }
120
121 // If the argument's to_string() fails, have the whole function call fail.
122 stdout.lock().push_str(&arg.to_string()?);
123 }
124 stdout.lock().push('\n');
125
126 Ok(())
127 }
128
129 /// Sandboxed io.open() function in Lua.
130 fn io_open(
131 lua: &Lua,
132 fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
133 root_dir: Option<&Arc<Path>>,
134 path_str: String,
135 mode: Option<String>,
136 ) -> mlua::Result<(Option<Table>, String)> {
137 let root_dir = root_dir
138 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
139
140 let mode = mode.unwrap_or_else(|| "r".to_string());
141
142 // Parse the mode string to determine read/write permissions
143 let read_perm = mode.contains('r');
144 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
145 let append = mode.contains('a');
146 let truncate = mode.contains('w');
147
148 // This will be the Lua value returned from the `open` function.
149 let file = lua.create_table()?;
150
151 // Store file metadata in the file
152 file.set("__path", path_str.clone())?;
153 file.set("__mode", mode.clone())?;
154 file.set("__read_perm", read_perm)?;
155 file.set("__write_perm", write_perm)?;
156
157 // Sandbox the path; it must be within root_dir
158 let path: PathBuf = {
159 let rust_path = Path::new(&path_str);
160
161 // Get absolute path
162 if rust_path.is_absolute() {
163 // Check if path starts with root_dir prefix without resolving symlinks
164 if !rust_path.starts_with(&root_dir) {
165 return Ok((
166 None,
167 format!(
168 "Error: Absolute path {} is outside the current working directory",
169 path_str
170 ),
171 ));
172 }
173 rust_path.to_path_buf()
174 } else {
175 // Make relative path absolute relative to cwd
176 root_dir.join(rust_path)
177 }
178 };
179
180 // close method
181 let close_fn = {
182 let fs_changes = fs_changes.clone();
183 lua.create_function(move |_lua, file_userdata: mlua::Table| {
184 let write_perm = file_userdata.get::<bool>("__write_perm")?;
185 let path = file_userdata.get::<String>("__path")?;
186
187 if write_perm {
188 // When closing a writable file, record the content
189 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
190 let content_ref = content.borrow::<FileContent>()?;
191 let content_vec = content_ref.0.borrow();
192
193 // Don't actually write to disk; instead, just update fs_changes.
194 let path_buf = PathBuf::from(&path);
195 fs_changes
196 .lock()
197 .insert(path_buf.clone(), content_vec.clone());
198 }
199
200 Ok(true)
201 })?
202 };
203 file.set("close", close_fn)?;
204
205 // If it's a directory, give it a custom read() and return early.
206 if path.is_dir() {
207 // TODO handle the case where we changed it in the in-memory fs
208
209 // Create a special directory handle
210 file.set("__is_directory", true)?;
211
212 // Store directory entries
213 let entries = match std::fs::read_dir(&path) {
214 Ok(entries) => {
215 let mut entry_names = Vec::new();
216 for entry in entries.flatten() {
217 entry_names.push(entry.file_name().to_string_lossy().into_owned());
218 }
219 entry_names
220 }
221 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
222 };
223
224 // Save the list of entries
225 file.set("__dir_entries", entries)?;
226 file.set("__dir_position", 0usize)?;
227
228 // Create a directory-specific read function
229 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
230 let position = file_userdata.get::<usize>("__dir_position")?;
231 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
232
233 if position >= entries.len() {
234 return Ok(None); // No more entries
235 }
236
237 let entry = entries[position].clone();
238 file_userdata.set("__dir_position", position + 1)?;
239
240 Ok(Some(entry))
241 })?;
242 file.set("read", read_fn)?;
243
244 // If we got this far, the directory was opened successfully
245 return Ok((Some(file), String::new()));
246 }
247
248 let fs_changes_map = fs_changes.lock();
249
250 let is_in_changes = fs_changes_map.contains_key(&path);
251 let file_exists = is_in_changes || path.exists();
252 let mut file_content = Vec::new();
253
254 if file_exists && !truncate {
255 if is_in_changes {
256 file_content = fs_changes_map.get(&path).unwrap().clone();
257 } else {
258 // Try to read existing content if file exists and we're not truncating
259 match std::fs::read(&path) {
260 Ok(content) => file_content = content,
261 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
262 }
263 }
264 }
265
266 drop(fs_changes_map); // Unlock the fs_changes mutex.
267
268 // If in append mode, position should be at the end
269 let position = if append && file_exists {
270 file_content.len()
271 } else {
272 0
273 };
274 file.set("__position", position)?;
275 file.set(
276 "__content",
277 lua.create_userdata(FileContent(RefCell::new(file_content)))?,
278 )?;
279
280 // Create file methods
281
282 // read method
283 let read_fn = {
284 lua.create_function(
285 |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
286 let read_perm = file_userdata.get::<bool>("__read_perm")?;
287 if !read_perm {
288 return Err(mlua::Error::runtime("File not open for reading"));
289 }
290
291 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
292 let mut position = file_userdata.get::<usize>("__position")?;
293 let content_ref = content.borrow::<FileContent>()?;
294 let content_vec = content_ref.0.borrow();
295
296 if position >= content_vec.len() {
297 return Ok(None); // EOF
298 }
299
300 match format {
301 Some(mlua::Value::String(s)) => {
302 let lossy_string = s.to_string_lossy();
303 let format_str: &str = lossy_string.as_ref();
304
305 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
306 match &format_str[0..2] {
307 "*a" => {
308 // Read entire file from current position
309 let result = String::from_utf8_lossy(&content_vec[position..])
310 .to_string();
311 position = content_vec.len();
312 file_userdata.set("__position", position)?;
313 Ok(Some(result))
314 }
315 "*l" => {
316 // Read next line
317 let mut line = Vec::new();
318 let mut found_newline = false;
319
320 while position < content_vec.len() {
321 let byte = content_vec[position];
322 position += 1;
323
324 if byte == b'\n' {
325 found_newline = true;
326 break;
327 }
328
329 // Skip \r in \r\n sequence but add it if it's alone
330 if byte == b'\r' {
331 if position < content_vec.len()
332 && content_vec[position] == b'\n'
333 {
334 position += 1;
335 found_newline = true;
336 break;
337 }
338 }
339
340 line.push(byte);
341 }
342
343 file_userdata.set("__position", position)?;
344
345 if !found_newline
346 && line.is_empty()
347 && position >= content_vec.len()
348 {
349 return Ok(None); // EOF
350 }
351
352 let result = String::from_utf8_lossy(&line).to_string();
353 Ok(Some(result))
354 }
355 "*n" => {
356 // Try to parse as a number (number of bytes to read)
357 match format_str.parse::<usize>() {
358 Ok(n) => {
359 let end =
360 std::cmp::min(position + n, content_vec.len());
361 let bytes = &content_vec[position..end];
362 let result = String::from_utf8_lossy(bytes).to_string();
363 position = end;
364 file_userdata.set("__position", position)?;
365 Ok(Some(result))
366 }
367 Err(_) => Err(mlua::Error::runtime(format!(
368 "Invalid format: {}",
369 format_str
370 ))),
371 }
372 }
373 "*L" => {
374 // Read next line keeping the end of line
375 let mut line = Vec::new();
376
377 while position < content_vec.len() {
378 let byte = content_vec[position];
379 position += 1;
380
381 line.push(byte);
382
383 if byte == b'\n' {
384 break;
385 }
386
387 // If we encounter a \r, add it and check if the next is \n
388 if byte == b'\r'
389 && position < content_vec.len()
390 && content_vec[position] == b'\n'
391 {
392 line.push(content_vec[position]);
393 position += 1;
394 break;
395 }
396 }
397
398 file_userdata.set("__position", position)?;
399
400 if line.is_empty() && position >= content_vec.len() {
401 return Ok(None); // EOF
402 }
403
404 let result = String::from_utf8_lossy(&line).to_string();
405 Ok(Some(result))
406 }
407 _ => Err(mlua::Error::runtime(format!(
408 "Unsupported format: {}",
409 format_str
410 ))),
411 }
412 }
413 Some(mlua::Value::Number(n)) => {
414 // Read n bytes
415 let n = n as usize;
416 let end = std::cmp::min(position + n, content_vec.len());
417 let bytes = &content_vec[position..end];
418 let result = String::from_utf8_lossy(bytes).to_string();
419 position = end;
420 file_userdata.set("__position", position)?;
421 Ok(Some(result))
422 }
423 Some(_) => Err(mlua::Error::runtime("Invalid format")),
424 None => {
425 // Default is to read a line
426 let mut line = Vec::new();
427 let mut found_newline = false;
428
429 while position < content_vec.len() {
430 let byte = content_vec[position];
431 position += 1;
432
433 if byte == b'\n' {
434 found_newline = true;
435 break;
436 }
437
438 // Handle \r\n
439 if byte == b'\r' {
440 if position < content_vec.len()
441 && content_vec[position] == b'\n'
442 {
443 position += 1;
444 found_newline = true;
445 break;
446 }
447 }
448
449 line.push(byte);
450 }
451
452 file_userdata.set("__position", position)?;
453
454 if !found_newline && line.is_empty() && position >= content_vec.len() {
455 return Ok(None); // EOF
456 }
457
458 let result = String::from_utf8_lossy(&line).to_string();
459 Ok(Some(result))
460 }
461 }
462 },
463 )?
464 };
465 file.set("read", read_fn)?;
466
467 // write method
468 let write_fn = {
469 let fs_changes = fs_changes.clone();
470
471 lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
472 let write_perm = file_userdata.get::<bool>("__write_perm")?;
473 if !write_perm {
474 return Err(mlua::Error::runtime("File not open for writing"));
475 }
476
477 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
478 let position = file_userdata.get::<usize>("__position")?;
479 let content_ref = content.borrow::<FileContent>()?;
480 let mut content_vec = content_ref.0.borrow_mut();
481
482 let bytes = text.as_bytes();
483
484 // Ensure the vector has enough capacity
485 if position + bytes.len() > content_vec.len() {
486 content_vec.resize(position + bytes.len(), 0);
487 }
488
489 // Write the bytes
490 for (i, &byte) in bytes.iter().enumerate() {
491 content_vec[position + i] = byte;
492 }
493
494 // Update position
495 let new_position = position + bytes.len();
496 file_userdata.set("__position", new_position)?;
497
498 // Update fs_changes
499 let path = file_userdata.get::<String>("__path")?;
500 let path_buf = PathBuf::from(path);
501 fs_changes.lock().insert(path_buf, content_vec.clone());
502
503 Ok(true)
504 })?
505 };
506 file.set("write", write_fn)?;
507
508 // If we got this far, the file was opened successfully
509 Ok((Some(file), String::new()))
510 }
511
512 async fn search(
513 lua: Lua,
514 mut foreground_tx: mpsc::Sender<ForegroundFn>,
515 fs: Arc<dyn Fs>,
516 regex: String,
517 ) -> mlua::Result<Table> {
518 // TODO: Allow specification of these options.
519 let search_query = SearchQuery::regex(
520 ®ex,
521 false,
522 false,
523 false,
524 PathMatcher::default(),
525 PathMatcher::default(),
526 None,
527 );
528 let search_query = match search_query {
529 Ok(query) => query,
530 Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
531 };
532
533 // TODO: Should use `search_query.regex`. The tool description should also be updated,
534 // as it specifies standard regex.
535 let search_regex = match Regex::new(®ex) {
536 Ok(re) => re,
537 Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
538 };
539
540 let mut abs_paths_rx =
541 Self::find_search_candidates(search_query, &mut foreground_tx).await?;
542
543 let mut search_results: Vec<Table> = Vec::new();
544 while let Some(path) = abs_paths_rx.next().await {
545 // Skip files larger than 1MB
546 if let Ok(Some(metadata)) = fs.metadata(&path).await {
547 if metadata.len > 1_000_000 {
548 continue;
549 }
550 }
551
552 // Attempt to read the file as text
553 if let Ok(content) = fs.load(&path).await {
554 let mut matches = Vec::new();
555
556 // Find all regex matches in the content
557 for capture in search_regex.find_iter(&content) {
558 matches.push(capture.as_str().to_string());
559 }
560
561 // If we found matches, create a result entry
562 if !matches.is_empty() {
563 let result_entry = lua.create_table()?;
564 result_entry.set("path", path.to_string_lossy().to_string())?;
565
566 let matches_table = lua.create_table()?;
567 for (ix, m) in matches.iter().enumerate() {
568 matches_table.set(ix + 1, m.clone())?;
569 }
570 result_entry.set("matches", matches_table)?;
571
572 search_results.push(result_entry);
573 }
574 }
575 }
576
577 // Create a table to hold our results
578 let results_table = lua.create_table()?;
579 for (ix, entry) in search_results.into_iter().enumerate() {
580 results_table.set(ix + 1, entry)?;
581 }
582
583 Ok(results_table)
584 }
585
586 async fn find_search_candidates(
587 search_query: SearchQuery,
588 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
589 ) -> mlua::Result<mpsc::UnboundedReceiver<PathBuf>> {
590 Self::run_foreground_fn(
591 "finding search file candidates",
592 foreground_tx,
593 Box::new(move |session, mut cx| {
594 session.update(&mut cx, |session, cx| {
595 session.project.update(cx, |project, cx| {
596 project.worktree_store().update(cx, |worktree_store, cx| {
597 // TODO: Better limit? For now this is the same as
598 // MAX_SEARCH_RESULT_FILES.
599 let limit = 5000;
600 // TODO: Providing non-empty open_entries can make this a bit more
601 // efficient as it can skip checking that these paths are textual.
602 let open_entries = HashSet::default();
603 let candidates = worktree_store.find_search_candidates(
604 search_query,
605 limit,
606 open_entries,
607 project.fs().clone(),
608 cx,
609 );
610 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
611 cx.spawn(|worktree_store, cx| async move {
612 pin_mut!(candidates);
613
614 while let Some(project_path) = candidates.next().await {
615 worktree_store.read_with(&cx, |worktree_store, cx| {
616 if let Some(worktree) = worktree_store
617 .worktree_for_id(project_path.worktree_id, cx)
618 {
619 if let Some(abs_path) = worktree
620 .read(cx)
621 .absolutize(&project_path.path)
622 .log_err()
623 {
624 abs_paths_tx.unbounded_send(abs_path)?;
625 }
626 }
627 anyhow::Ok(())
628 })??;
629 }
630 anyhow::Ok(())
631 })
632 .detach();
633 abs_paths_rx
634 })
635 })
636 })
637 }),
638 )
639 .await
640 }
641
642 async fn run_foreground_fn<R: Send + 'static>(
643 description: &str,
644 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
645 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> anyhow::Result<R> + Send>,
646 ) -> mlua::Result<R> {
647 let (response_tx, response_rx) = oneshot::channel();
648 let send_result = foreground_tx
649 .send(ForegroundFn(Box::new(move |this, cx| {
650 response_tx.send(function(this, cx)).ok();
651 })))
652 .await;
653 match send_result {
654 Ok(()) => (),
655 Err(err) => {
656 return Err(mlua::Error::runtime(format!(
657 "Internal error while enqueuing work for {description}: {err}"
658 )))
659 }
660 }
661 match response_rx.await {
662 Ok(Ok(result)) => Ok(result),
663 Ok(Err(err)) => Err(mlua::Error::runtime(format!(
664 "Error while {description}: {err}"
665 ))),
666 Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
667 "Internal error: response oneshot was canceled while {description}."
668 ))),
669 }
670 }
671}
672
673struct FileContent(RefCell<Vec<u8>>);
674
675impl UserData for FileContent {
676 fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
677 // FileContent doesn't have any methods so far.
678 }
679}
680
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    /// `print` separates its arguments with tabs and ends every call with a newline.
    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        let stdout = run_in_session(
            project,
            r#"
                print("Hello", "world!")
                print("Goodbye", "moon!")
            "#,
            cx,
        )
        .await;

        assert_eq!(stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    /// `search` reports only files with matches, listing each match.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "file1.txt": "Hello world!",
                "file2.txt": "Goodbye moon!"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new("/")], cx).await;

        let stdout = run_in_session(
            project,
            r#"
                local results = search("world")
                for i, result in ipairs(results) do
                    print("File: " .. result.path)
                    print("Matches:")
                    for j, match in ipairs(result.matches) do
                        print(" " .. match)
                    end
                end
            "#,
            cx,
        )
        .await;

        assert_eq!(stdout, "File: /file1.txt\nMatches:\n world\n");
    }

    /// Creates a fresh `Session` for `project`, runs `script` in it and returns the
    /// captured stdout, panicking if the script fails.
    async fn run_in_session(
        project: Entity<Project>,
        script: &str,
        cx: &mut TestAppContext,
    ) -> String {
        let session = cx.new(|cx| Session::new(project, cx));
        let output = session
            .update(cx, |session, cx| session.run_script(script.to_string(), cx))
            .await
            .unwrap();
        output.stdout
    }

    /// Installs the settings globals that `Project` needs in tests.
    fn init_test(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(Project::init_settings);
    }
}