1use anyhow::anyhow;
2use collections::HashSet;
3use futures::{
4 channel::{mpsc, oneshot},
5 pin_mut, SinkExt, StreamExt,
6};
7use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
8use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
9use parking_lot::Mutex;
10use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
11use regex::Regex;
12use std::{
13 path::{Path, PathBuf},
14 sync::Arc,
15};
16use util::{paths::PathMatcher, ResultExt};
17
/// A unit of work that a Lua callback needs executed on the session's own
/// `cx.spawn` task rather than on the background task running the Lua VM
/// (see `ScriptingSession::run_lua`, which uses `background_spawn`).
///
/// Callbacks send these over `ScriptingSession::foreground_fns_tx`; the task
/// stored in `_invoke_foreground_fns` calls each one with a weak handle to the
/// session and an `AsyncApp`.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
19
/// Runs sandboxed Lua scripts against a [`Project`] and records each script's state.
pub struct ScriptingSession {
    /// Project whose first visible worktree is exposed to scripts.
    project: Entity<Project>,
    /// Sender used by Lua callbacks to enqueue [`ForegroundFn`]s.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task that receives and invokes queued [`ForegroundFn`]s; stored so it
    /// lives as long as the session. NOTE(review): presumably dropping this
    /// gpui `Task` cancels it — confirm before reordering fields.
    _invoke_foreground_fns: Task<()>,
    /// Every script started by this session; `ScriptId` is an index into it.
    scripts: Vec<Script>,
}
26
27impl ScriptingSession {
28 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
29 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
30 ScriptingSession {
31 project,
32 foreground_fns_tx,
33 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
34 while let Some(foreground_fn) = foreground_fns_rx.next().await {
35 foreground_fn.0(this.clone(), cx.clone());
36 }
37 }),
38 scripts: Vec::new(),
39 }
40 }
41
42 pub fn run_script(
43 &mut self,
44 script_src: String,
45 cx: &mut Context<Self>,
46 ) -> (ScriptId, Task<()>) {
47 let id = ScriptId(self.scripts.len() as u32);
48
49 let stdout = Arc::new(Mutex::new(String::new()));
50
51 let script = Script {
52 state: ScriptState::Running {
53 stdout: stdout.clone(),
54 },
55 };
56 self.scripts.push(script);
57
58 let task = self.run_lua(script_src, stdout, cx);
59
60 let task = cx.spawn(|session, mut cx| async move {
61 let result = task.await;
62
63 session
64 .update(&mut cx, |session, _cx| {
65 let script = session.get_mut(id);
66 let stdout = script.stdout_snapshot();
67
68 script.state = match result {
69 Ok(()) => ScriptState::Succeeded { stdout },
70 Err(error) => ScriptState::Failed { stdout, error },
71 };
72 })
73 .log_err();
74 });
75
76 (id, task)
77 }
78
79 fn run_lua(
80 &mut self,
81 script: String,
82 stdout: Arc<Mutex<String>>,
83 cx: &mut Context<Self>,
84 ) -> Task<anyhow::Result<()>> {
85 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
86
87 // TODO Honor all worktrees instead of the first one
88 let worktree_info = self
89 .project
90 .read(cx)
91 .visible_worktrees(cx)
92 .next()
93 .map(|worktree| {
94 let worktree = worktree.read(cx);
95 (worktree.id(), worktree.abs_path())
96 });
97
98 let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
99
100 let fs = self.project.read(cx).fs().clone();
101 let foreground_fns_tx = self.foreground_fns_tx.clone();
102
103 let task = cx.background_spawn({
104 let stdout = stdout.clone();
105
106 async move {
107 let lua = Lua::new();
108 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
109 let globals = lua.globals();
110
111 // Use the project root dir as the script's current working dir.
112 if let Some(root_dir) = &root_dir {
113 if let Some(root_dir) = root_dir.to_str() {
114 globals.set("cwd", root_dir)?;
115 }
116 }
117
118 globals.set(
119 "sb_print",
120 lua.create_function({
121 let stdout = stdout.clone();
122 move |_, args: MultiValue| Self::print(args, &stdout)
123 })?,
124 )?;
125 globals.set(
126 "search",
127 lua.create_async_function({
128 let foreground_fns_tx = foreground_fns_tx.clone();
129 let fs = fs.clone();
130 move |lua, regex| {
131 let mut foreground_fns_tx = foreground_fns_tx.clone();
132 let fs = fs.clone();
133 async move {
134 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
135 .await
136 .into_lua_err()
137 }
138 }
139 })?,
140 )?;
141 globals.set(
142 "outline",
143 lua.create_async_function({
144 let root_dir = root_dir.clone();
145 let foreground_fns_tx = foreground_fns_tx.clone();
146 move |_lua, path| {
147 let mut foreground_fns_tx = foreground_fns_tx.clone();
148 let root_dir = root_dir.clone();
149 async move {
150 Self::outline(root_dir, &mut foreground_fns_tx, path)
151 .await
152 .into_lua_err()
153 }
154 }
155 })?,
156 )?;
157 globals.set(
158 "sb_io_open",
159 lua.create_async_function({
160 let worktree_info = worktree_info.clone();
161 let foreground_fns_tx = foreground_fns_tx.clone();
162 move |lua, (path_str, mode)| {
163 let worktree_info = worktree_info.clone();
164 let mut foreground_fns_tx = foreground_fns_tx.clone();
165 let fs = fs.clone();
166 async move {
167 Self::io_open(
168 &lua,
169 worktree_info,
170 &mut foreground_fns_tx,
171 fs,
172 path_str,
173 mode,
174 )
175 .await
176 }
177 }
178 })?,
179 )?;
180 globals.set("user_script", script)?;
181
182 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
183
184 // Drop Lua instance to decrement reference count.
185 drop(lua);
186
187 anyhow::Ok(())
188 }
189 });
190
191 task
192 }
193
194 pub fn get(&self, script_id: ScriptId) -> &Script {
195 &self.scripts[script_id.0 as usize]
196 }
197
198 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
199 &mut self.scripts[script_id.0 as usize]
200 }
201
202 /// Sandboxed print() function in Lua.
203 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
204 for (index, arg) in args.into_iter().enumerate() {
205 // Lua's `print()` prints tab characters between each argument.
206 if index > 0 {
207 stdout.lock().push('\t');
208 }
209
210 // If the argument's to_string() fails, have the whole function call fail.
211 stdout.lock().push_str(&arg.to_string()?);
212 }
213 stdout.lock().push('\n');
214
215 Ok(())
216 }
217
218 /// Sandboxed io.open() function in Lua.
219 async fn io_open(
220 lua: &Lua,
221 worktree_info: Option<(WorktreeId, Arc<Path>)>,
222 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
223 fs: Arc<dyn Fs>,
224 path_str: String,
225 mode: Option<String>,
226 ) -> mlua::Result<(Option<Table>, String)> {
227 let (worktree_id, root_dir) = worktree_info
228 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
229
230 let mode = mode.unwrap_or_else(|| "r".to_string());
231
232 // Parse the mode string to determine read/write permissions
233 let read_perm = mode.contains('r');
234 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
235 let append = mode.contains('a');
236 let truncate = mode.contains('w');
237
238 // This will be the Lua value returned from the `open` function.
239 let file = lua.create_table()?;
240
241 // Store file metadata in the file
242 file.set("__mode", mode.clone())?;
243 file.set("__read_perm", read_perm)?;
244 file.set("__write_perm", write_perm)?;
245
246 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
247 Ok(path) => path,
248 Err(err) => return Ok((None, format!("{err}"))),
249 };
250
251 let project_path = ProjectPath {
252 worktree_id,
253 path: Path::new(&path_str).into(),
254 };
255
256 // flush / close method
257 let flush_fn = {
258 let project_path = project_path.clone();
259 let foreground_tx = foreground_tx.clone();
260 lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
261 let project_path = project_path.clone();
262 let mut foreground_tx = foreground_tx.clone();
263 async move {
264 Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
265 }
266 })?
267 };
268 file.set("flush", flush_fn.clone())?;
269 // We don't really hold files open, so we only need to flush on close
270 file.set("close", flush_fn)?;
271
272 // If it's a directory, give it a custom read() and return early.
273 if fs.is_dir(&path).await {
274 return Self::io_file_dir(lua, fs, file, &path).await;
275 }
276
277 let mut file_content = Vec::new();
278
279 if !truncate {
280 // Try to read existing content if we're not truncating
281 match Self::read_buffer(project_path.clone(), foreground_tx).await {
282 Ok(content) => file_content = content.into_bytes(),
283 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
284 }
285 }
286
287 // If in append mode, position should be at the end
288 let position = if append { file_content.len() } else { 0 };
289 file.set("__position", position)?;
290 file.set(
291 "__content",
292 lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
293 )?;
294
295 // Create file methods
296
297 // read method
298 let read_fn = lua.create_function(Self::io_file_read)?;
299 file.set("read", read_fn)?;
300
301 // write method
302 let write_fn = lua.create_function(Self::io_file_write)?;
303 file.set("write", write_fn)?;
304
305 // If we got this far, the file was opened successfully
306 Ok((Some(file), String::new()))
307 }
308
309 async fn read_buffer(
310 project_path: ProjectPath,
311 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
312 ) -> anyhow::Result<String> {
313 Self::run_foreground_fn(
314 "read file from buffer",
315 foreground_tx,
316 Box::new(move |session, mut cx| {
317 session.update(&mut cx, |session, cx| {
318 let open_buffer_task = session
319 .project
320 .update(cx, |project, cx| project.open_buffer(project_path, cx));
321
322 cx.spawn(|_, cx| async move {
323 let buffer = open_buffer_task.await?;
324
325 let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
326 Ok(text)
327 })
328 })
329 }),
330 )
331 .await??
332 .await
333 }
334
335 async fn io_file_flush(
336 file_userdata: mlua::Table,
337 project_path: ProjectPath,
338 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
339 ) -> mlua::Result<bool> {
340 let write_perm = file_userdata.get::<bool>("__write_perm")?;
341
342 if write_perm {
343 // When closing a writable file, record the content
344 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
345 let content_ref = content.borrow::<FileContent>()?;
346 let text = {
347 let mut content_vec = content_ref.0.lock();
348 let content_vec = std::mem::take(&mut *content_vec);
349 String::from_utf8(content_vec).into_lua_err()?
350 };
351
352 Self::write_to_buffer(project_path, text, foreground_tx)
353 .await
354 .into_lua_err()?;
355 }
356
357 Ok(true)
358 }
359
    /// Replaces the contents of the buffer at `project_path` with `text`, then saves it.
    ///
    /// The steps run through `run_foreground_fn`:
    /// 1. Open the buffer.
    /// 2. Compute `buffer.diff(text)` and apply that diff.
    /// 3. Ask the project to save the buffer.
    ///
    /// Errors from any step are propagated to the caller.
    async fn write_to_buffer(
        project_path: ProjectPath,
        text: String,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<()> {
        Self::run_foreground_fn(
            "write to buffer",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    let open_buffer_task = session
                        .project
                        .update(cx, |project, cx| project.open_buffer(project_path, cx));

                    cx.spawn(move |session, mut cx| async move {
                        let buffer = open_buffer_task.await?;

                        // The new text is applied as a diff against the current
                        // buffer contents rather than set wholesale.
                        let diff = buffer
                            .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
                            .await;

                        buffer.update(&mut cx, |buffer, cx| {
                            buffer.apply_diff(diff, cx);
                        })?;

                        session
                            .update(&mut cx, |session, cx| {
                                session
                                    .project
                                    .update(cx, |project, cx| project.save_buffer(buffer, cx))
                            })?
                            .await
                    })
                })
            }),
        )
        // Outer `?`: enqueue/response errors; inner `?`: session update error;
        // the final `.await` waits for the spawned open/diff/save task.
        .await??
        .await
    }
399
400 async fn io_file_dir(
401 lua: &Lua,
402 fs: Arc<dyn Fs>,
403 file: Table,
404 path: &Path,
405 ) -> mlua::Result<(Option<Table>, String)> {
406 // Create a special directory handle
407 file.set("__is_directory", true)?;
408
409 // Store directory entries
410 let entries = match fs.read_dir(&path).await {
411 Ok(entries) => {
412 let mut entry_names = Vec::new();
413
414 // Process the stream of directory entries
415 pin_mut!(entries);
416 while let Some(Ok(entry_result)) = entries.next().await {
417 if let Some(file_name) = entry_result.file_name() {
418 entry_names.push(file_name.to_string_lossy().into_owned());
419 }
420 }
421
422 entry_names
423 }
424 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
425 };
426
427 // Save the list of entries
428 file.set("__dir_entries", entries)?;
429 file.set("__dir_position", 0usize)?;
430
431 // Create a directory-specific read function
432 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
433 let position = file_userdata.get::<usize>("__dir_position")?;
434 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
435
436 if position >= entries.len() {
437 return Ok(None); // No more entries
438 }
439
440 let entry = entries[position].clone();
441 file_userdata.set("__dir_position", position + 1)?;
442
443 Ok(Some(entry))
444 })?;
445 file.set("read", read_fn)?;
446
447 // If we got this far, the directory was opened successfully
448 return Ok((Some(file), String::new()));
449 }
450
451 fn io_file_read(
452 lua: &Lua,
453 (file_userdata, format): (Table, Option<mlua::Value>),
454 ) -> mlua::Result<Option<mlua::String>> {
455 let read_perm = file_userdata.get::<bool>("__read_perm")?;
456 if !read_perm {
457 return Err(mlua::Error::runtime("File not open for reading"));
458 }
459
460 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
461 let position = file_userdata.get::<usize>("__position")?;
462 let content_ref = content.borrow::<FileContent>()?;
463 let content = content_ref.0.lock();
464
465 if position >= content.len() {
466 return Ok(None); // EOF
467 }
468
469 let (result, new_position) = match Self::io_file_read_format(format)? {
470 FileReadFormat::All => {
471 // Read entire file from current position
472 let result = content[position..].to_vec();
473 (Some(result), content.len())
474 }
475 FileReadFormat::Line => {
476 if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
477 {
478 let mut line = content[position..position + next_newline_ix].to_vec();
479 if line.ends_with(b"\r") {
480 line.pop();
481 }
482 (Some(line), position + next_newline_ix + 1)
483 } else if position < content.len() {
484 let line = content[position..].to_vec();
485 (Some(line), content.len())
486 } else {
487 (None, position) // EOF
488 }
489 }
490 FileReadFormat::LineWithLineFeed => {
491 if position < content.len() {
492 let next_line_ix = content[position..]
493 .iter()
494 .position(|c| *c == b'\n')
495 .map_or(content.len(), |ix| position + ix + 1);
496 let line = content[position..next_line_ix].to_vec();
497 (Some(line), next_line_ix)
498 } else {
499 (None, position) // EOF
500 }
501 }
502 FileReadFormat::Bytes(n) => {
503 let end = std::cmp::min(position + n, content.len());
504 let result = content[position..end].to_vec();
505 (Some(result), end)
506 }
507 };
508
509 // Update the position in the file userdata
510 if new_position != position {
511 file_userdata.set("__position", new_position)?;
512 }
513
514 // Convert the result to a Lua string
515 match result {
516 Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
517 None => Ok(None),
518 }
519 }
520
521 fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
522 let format = match format {
523 Some(mlua::Value::String(s)) => {
524 let lossy_string = s.to_string_lossy();
525 let format_str: &str = lossy_string.as_ref();
526
527 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
528 match &format_str[0..2] {
529 "*a" => FileReadFormat::All,
530 "*l" => FileReadFormat::Line,
531 "*L" => FileReadFormat::LineWithLineFeed,
532 "*n" => {
533 // Try to parse as a number (number of bytes to read)
534 match format_str.parse::<usize>() {
535 Ok(n) => FileReadFormat::Bytes(n),
536 Err(_) => {
537 return Err(mlua::Error::runtime(format!(
538 "Invalid format: {}",
539 format_str
540 )))
541 }
542 }
543 }
544 _ => {
545 return Err(mlua::Error::runtime(format!(
546 "Unsupported format: {}",
547 format_str
548 )))
549 }
550 }
551 }
552 Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
553 Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
554 Some(value) => {
555 return Err(mlua::Error::runtime(format!(
556 "Invalid file format {:?}",
557 value
558 )))
559 }
560 None => FileReadFormat::Line, // Default is to read a line
561 };
562
563 Ok(format)
564 }
565
566 fn io_file_write(
567 _lua: &Lua,
568 (file_userdata, text): (Table, mlua::String),
569 ) -> mlua::Result<bool> {
570 let write_perm = file_userdata.get::<bool>("__write_perm")?;
571 if !write_perm {
572 return Err(mlua::Error::runtime("File not open for writing"));
573 }
574
575 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
576 let position = file_userdata.get::<usize>("__position")?;
577 let content_ref = content.borrow::<FileContent>()?;
578 let mut content_vec = content_ref.0.lock();
579
580 let bytes = text.as_bytes();
581
582 // Ensure the vector has enough capacity
583 if position + bytes.len() > content_vec.len() {
584 content_vec.resize(position + bytes.len(), 0);
585 }
586
587 // Write the bytes
588 for (i, &byte) in bytes.iter().enumerate() {
589 content_vec[position + i] = byte;
590 }
591
592 // Update position
593 let new_position = position + bytes.len();
594 file_userdata.set("__position", new_position)?;
595
596 Ok(true)
597 }
598
599 async fn search(
600 lua: &Lua,
601 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
602 fs: Arc<dyn Fs>,
603 regex: String,
604 ) -> anyhow::Result<Table> {
605 // TODO: Allow specification of these options.
606 let search_query = SearchQuery::regex(
607 ®ex,
608 false,
609 false,
610 false,
611 PathMatcher::default(),
612 PathMatcher::default(),
613 None,
614 );
615 let search_query = match search_query {
616 Ok(query) => query,
617 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
618 };
619
620 // TODO: Should use `search_query.regex`. The tool description should also be updated,
621 // as it specifies standard regex.
622 let search_regex = match Regex::new(®ex) {
623 Ok(re) => re,
624 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
625 };
626
627 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
628
629 let mut search_results: Vec<Table> = Vec::new();
630 while let Some(path) = abs_paths_rx.next().await {
631 // Skip files larger than 1MB
632 if let Ok(Some(metadata)) = fs.metadata(&path).await {
633 if metadata.len > 1_000_000 {
634 continue;
635 }
636 }
637
638 // Attempt to read the file as text
639 if let Ok(content) = fs.load(&path).await {
640 let mut matches = Vec::new();
641
642 // Find all regex matches in the content
643 for capture in search_regex.find_iter(&content) {
644 matches.push(capture.as_str().to_string());
645 }
646
647 // If we found matches, create a result entry
648 if !matches.is_empty() {
649 let result_entry = lua.create_table()?;
650 result_entry.set("path", path.to_string_lossy().to_string())?;
651
652 let matches_table = lua.create_table()?;
653 for (ix, m) in matches.iter().enumerate() {
654 matches_table.set(ix + 1, m.clone())?;
655 }
656 result_entry.set("matches", matches_table)?;
657
658 search_results.push(result_entry);
659 }
660 }
661 }
662
663 // Create a table to hold our results
664 let results_table = lua.create_table()?;
665 for (ix, entry) in search_results.into_iter().enumerate() {
666 results_table.set(ix + 1, entry)?;
667 }
668
669 Ok(results_table)
670 }
671
    /// Streams absolute paths of files that may match `search_query`.
    ///
    /// On the foreground task, this asks the project's worktree store for search
    /// candidates (limited to 5000). It then spawns a detached task that converts
    /// each candidate `ProjectPath` into an absolute path and sends it on the
    /// returned channel. Candidates whose worktree is not found, or whose path
    /// fails to absolutize (the error goes through `log_err`), are skipped.
    ///
    /// The spawned task stops when the candidate stream ends or when
    /// `read_with`/`unbounded_send` returns an error, for example after the
    /// receiver is dropped.
    async fn find_search_candidates(
        search_query: SearchQuery,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
        Self::run_foreground_fn(
            "finding search file candidates",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    session.project.update(cx, |project, cx| {
                        project.worktree_store().update(cx, |worktree_store, cx| {
                            // TODO: Better limit? For now this is the same as
                            // MAX_SEARCH_RESULT_FILES.
                            let limit = 5000;
                            // TODO: Providing non-empty open_entries can make this a bit more
                            // efficient as it can skip checking that these paths are textual.
                            let open_entries = HashSet::default();
                            let candidates = worktree_store.find_search_candidates(
                                search_query,
                                limit,
                                open_entries,
                                project.fs().clone(),
                                cx,
                            );
                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
                            cx.spawn(|worktree_store, cx| async move {
                                pin_mut!(candidates);

                                while let Some(project_path) = candidates.next().await {
                                    worktree_store.read_with(&cx, |worktree_store, cx| {
                                        if let Some(worktree) = worktree_store
                                            .worktree_for_id(project_path.worktree_id, cx)
                                        {
                                            if let Some(abs_path) = worktree
                                                .read(cx)
                                                .absolutize(&project_path.path)
                                                .log_err()
                                            {
                                                // Errors once the receiver is gone,
                                                // which ends this task via `??` below.
                                                abs_paths_tx.unbounded_send(abs_path)?;
                                            }
                                        }
                                        anyhow::Ok(())
                                    })??;
                                }
                                anyhow::Ok(())
                            })
                            .detach();
                            abs_paths_rx
                        })
                    })
                })
            }),
        )
        .await?
    }
727
728 async fn outline(
729 root_dir: Option<Arc<Path>>,
730 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
731 path_str: String,
732 ) -> anyhow::Result<String> {
733 let root_dir = root_dir
734 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
735 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
736 let outline = Self::run_foreground_fn(
737 "getting code outline",
738 foreground_tx,
739 Box::new(move |session, cx| {
740 cx.spawn(move |mut cx| async move {
741 // TODO: This will not use file content from `fs_changes`. It will also reflect
742 // user changes that have not been saved.
743 let buffer = session
744 .update(&mut cx, |session, cx| {
745 session
746 .project
747 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
748 })?
749 .await?;
750 buffer.update(&mut cx, |buffer, _cx| {
751 if let Some(outline) = buffer.snapshot().outline(None) {
752 Ok(outline)
753 } else {
754 Err(anyhow!("No outline for file {path_str}"))
755 }
756 })
757 })
758 }),
759 )
760 .await?
761 .await??;
762
763 Ok(outline
764 .items
765 .into_iter()
766 .map(|item| {
767 if item.text.contains('\n') {
768 log::error!("Outline item unexpectedly contains newline");
769 }
770 format!("{}{}", " ".repeat(item.depth), item.text)
771 })
772 .collect::<Vec<String>>()
773 .join("\n"))
774 }
775
776 async fn run_foreground_fn<R: Send + 'static>(
777 description: &str,
778 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
779 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
780 ) -> anyhow::Result<R> {
781 let (response_tx, response_rx) = oneshot::channel();
782 let send_result = foreground_tx
783 .send(ForegroundFn(Box::new(move |this, cx| {
784 response_tx.send(function(this, cx)).ok();
785 })))
786 .await;
787 match send_result {
788 Ok(()) => (),
789 Err(err) => {
790 return Err(anyhow::Error::new(err).context(format!(
791 "Internal error while enqueuing work for {description}"
792 )));
793 }
794 }
795 match response_rx.await {
796 Ok(result) => Ok(result),
797 Err(oneshot::Canceled) => Err(anyhow!(
798 "Internal error: response oneshot was canceled while {description}."
799 )),
800 }
801 }
802
803 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
804 let path = Path::new(&path_str);
805 if path.is_absolute() {
806 // Check if path starts with root_dir prefix without resolving symlinks
807 if path.starts_with(&root_dir) {
808 Ok(path.to_path_buf())
809 } else {
810 Err(anyhow!(
811 "Error: Absolute path {} is outside the current working directory",
812 path_str
813 ))
814 }
815 } else {
816 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
817 Ok(root_dir.join(path))
818 }
819 }
820}
821
/// What a `file:read` call should consume, parsed from its format argument.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FileReadFormat {
    /// "*a": everything from the current position to the end.
    All,
    /// "*l": the next line, returned without its trailing "\n" (or "\r\n").
    Line,
    /// "*L": the next line, including its trailing "\n".
    LineWithLineFeed,
    /// A number: up to that many bytes.
    Bytes(usize),
}
828
/// Byte content of a file opened through the sandboxed `io.open`.
///
/// Stored as Lua userdata under the file table's `__content` key. It is read
/// and modified by `read`/`write`, and written back by `flush`/`close`.
struct FileContent(Arc<Mutex<Vec<u8>>>);
830
831impl UserData for FileContent {
832 fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
833 // FileContent doesn't have any methods so far.
834 }
835}
836
/// Identifies a script within a `ScriptingSession`.
///
/// It is the script's index in the session's script list, so ids are assigned
/// in the order scripts are started.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
839
/// A script started by `ScriptingSession::run_script`.
pub struct Script {
    /// Current lifecycle state. It starts as `Running` and is replaced once the
    /// script finishes.
    pub state: ScriptState,
}
843
/// Lifecycle state of a script run by a `ScriptingSession`.
#[derive(Debug)]
pub enum ScriptState {
    /// The script is executing. `stdout` is shared with the sandboxed print
    /// function and grows as the script prints.
    Running {
        stdout: Arc<Mutex<String>>,
    },
    /// The script finished without error; `stdout` is everything it printed.
    Succeeded {
        stdout: String,
    },
    /// The script returned an error; `stdout` is what it printed before failing.
    Failed {
        stdout: String,
        error: anyhow::Error,
    },
}
857
858impl Script {
859 /// If exited, returns a message with the output for the LLM
860 pub fn output_message_for_llm(&self) -> Option<String> {
861 match &self.state {
862 ScriptState::Running { .. } => None,
863 ScriptState::Succeeded { stdout } => {
864 format!("Here's the script output:\n{}", stdout).into()
865 }
866 ScriptState::Failed { stdout, error } => format!(
867 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
868 error, stdout
869 )
870 .into(),
871 }
872 }
873
874 /// Get a snapshot of the script's stdout
875 pub fn stdout_snapshot(&self) -> String {
876 match &self.state {
877 ScriptState::Running { stdout } => stdout.lock().clone(),
878 ScriptState::Succeeded { stdout } => stdout.clone(),
879 ScriptState::Failed { stdout, .. } => stdout.clone(),
880 }
881 }
882}
883
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    // search

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print(" " .. match)
                end
            end
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert_eq!(output, "File: /file1.txt\nMatches:\n world\n");
    }

    // io.open

    #[gpui::test]
    async fn test_open_and_read_file(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("file1.txt", "r")
            local content = file:read()
            print("Content:", content)
            file:close()
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert_eq!(output, "Content:\tHello world!\n");
    }

    #[gpui::test]
    async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("new_file.txt", "w")
            file:write("This is new content")
            file:close()

            -- Read back to verify
            local read_file = io.open("new_file.txt", "r")
            if read_file then
                local content = read_file:read("*a")
                print("Written content:", content)
                read_file:close()
            end
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert_eq!(output, "Written content:\tThis is new content\n");
    }

    #[gpui::test]
    async fn test_multiple_writes(cx: &mut TestAppContext) {
        let script = r#"
            -- Test writing to a file multiple times
            local file = io.open("multiwrite.txt", "w")
            file:write("First line\n")
            file:write("Second line\n")
            file:write("Third line")
            file:close()

            -- Read back to verify
            local read_file = io.open("multiwrite.txt", "r")
            if read_file then
                local content = read_file:read("*a")
                print("Full content:", content)
                read_file:close()
            end
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert_eq!(
            output,
            "Full content:\tFirst line\nSecond line\nThird line\n"
        );
    }

    #[gpui::test]
    async fn test_read_formats(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("multiline.txt", "w")
            file:write("Line 1\nLine 2\nLine 3")
            file:close()

            -- Test "*a" (all)
            local f = io.open("multiline.txt", "r")
            local all = f:read("*a")
            print("All:", all)
            f:close()

            -- Test "*l" (line)
            f = io.open("multiline.txt", "r")
            local line1 = f:read("*l")
            local line2 = f:read("*l")
            local line3 = f:read("*l")
            print("Line 1:", line1)
            print("Line 2:", line2)
            print("Line 3:", line3)
            f:close()

            -- Test "*L" (line with newline)
            f = io.open("multiline.txt", "r")
            local line_with_nl = f:read("*L")
            print("Line with newline length:", #line_with_nl)
            print("Last char:", string.byte(line_with_nl, #line_with_nl))
            f:close()

            -- Test number of bytes
            f = io.open("multiline.txt", "r")
            local bytes5 = f:read(5)
            print("5 bytes:", bytes5)
            f:close()
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
        assert!(output.contains("Line 1:\tLine 1"));
        assert!(output.contains("Line 2:\tLine 2"));
        assert!(output.contains("Line 3:\tLine 3"));
        assert!(output.contains("Line with newline length:\t7"));
        assert!(output.contains("Last char:\t10")); // LF
        assert!(output.contains("5 bytes:\tLine "));
    }

    // helpers

    /// Runs `source` in a session over a fake project containing `file1.txt` and
    /// `file2.txt` at `/`, and returns the script's stdout.
    ///
    /// Panics (failing the test) if the script ends in `ScriptState::Failed`.
    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "file1.txt": "Hello world!",
                "file2.txt": "Goodbye moon!"
            }),
        )
        .await;

        let project = Project::test(fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptingSession::new(project, cx));

        let (script_id, task) =
            session.update(cx, |session, cx| session.run_script(source.to_string(), cx));

        task.await;

        Ok(session.read_with(cx, |session, _cx| {
            let script = session.get(script_id);
            let stdout = script.stdout_snapshot();

            if let ScriptState::Failed { error, .. } = &script.state {
                panic!("Script failed:\n{}\n\n{}", error, stdout);
            }

            stdout
        }))
    }

    /// Installs the test settings store and initializes project/language settings.
    fn init_test(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(Project::init_settings);
        cx.update(language::init);
    }
}