1use anyhow::anyhow;
2use collections::{HashMap, HashSet};
3use futures::{
4 channel::{mpsc, oneshot},
5 pin_mut, SinkExt, StreamExt,
6};
7use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
8use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
9use parking_lot::Mutex;
10use project::{search::SearchQuery, Fs, Project};
11use regex::Regex;
12use std::{
13 cell::RefCell,
14 path::{Path, PathBuf},
15 sync::Arc,
16};
17use util::{paths::PathMatcher, ResultExt};
18
19use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
20
21struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
22
23pub struct ScriptSession {
24 project: Entity<Project>,
25 // TODO Remove this
26 fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
27 foreground_fns_tx: mpsc::Sender<ForegroundFn>,
28 _invoke_foreground_fns: Task<()>,
29 scripts: Vec<Script>,
30}
31
32impl ScriptSession {
33 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
34 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
35 ScriptSession {
36 project,
37 fs_changes: Arc::new(Mutex::new(HashMap::default())),
38 foreground_fns_tx,
39 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
40 while let Some(foreground_fn) = foreground_fns_rx.next().await {
41 foreground_fn.0(this.clone(), cx.clone());
42 }
43 }),
44 scripts: Vec::new(),
45 }
46 }
47
48 pub fn new_script(&mut self) -> ScriptId {
49 let id = ScriptId(self.scripts.len() as u32);
50 let script = Script {
51 id,
52 state: ScriptState::Generating,
53 source: SharedString::new_static(""),
54 };
55 self.scripts.push(script);
56 id
57 }
58
59 pub fn run_script(
60 &mut self,
61 script_id: ScriptId,
62 script_src: String,
63 cx: &mut Context<Self>,
64 ) -> Task<anyhow::Result<()>> {
65 let script = self.get_mut(script_id);
66
67 let stdout = Arc::new(Mutex::new(String::new()));
68 script.source = script_src.clone().into();
69 script.state = ScriptState::Running {
70 stdout: stdout.clone(),
71 };
72
73 let task = self.run_lua(script_src, stdout, cx);
74
75 cx.emit(ScriptEvent::Spawned(script_id));
76
77 cx.spawn(|session, mut cx| async move {
78 let result = task.await;
79
80 session.update(&mut cx, |session, cx| {
81 let script = session.get_mut(script_id);
82 let stdout = script.stdout_snapshot();
83
84 script.state = match result {
85 Ok(()) => ScriptState::Succeeded { stdout },
86 Err(error) => ScriptState::Failed { stdout, error },
87 };
88
89 cx.emit(ScriptEvent::Exited(script_id))
90 })
91 })
92 }
93
94 fn run_lua(
95 &mut self,
96 script: String,
97 stdout: Arc<Mutex<String>>,
98 cx: &mut Context<Self>,
99 ) -> Task<anyhow::Result<()>> {
100 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
101
102 // TODO Remove fs_changes
103 let fs_changes = self.fs_changes.clone();
104 // TODO Honor all worktrees instead of the first one
105 let root_dir = self
106 .project
107 .read(cx)
108 .visible_worktrees(cx)
109 .next()
110 .map(|worktree| worktree.read(cx).abs_path());
111
112 let fs = self.project.read(cx).fs().clone();
113 let foreground_fns_tx = self.foreground_fns_tx.clone();
114
115 let task = cx.background_spawn({
116 let stdout = stdout.clone();
117
118 async move {
119 let lua = Lua::new();
120 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
121 let globals = lua.globals();
122
123 // Use the project root dir as the script's current working dir.
124 if let Some(root_dir) = &root_dir {
125 if let Some(root_dir) = root_dir.to_str() {
126 globals.set("cwd", root_dir)?;
127 }
128 }
129
130 globals.set(
131 "sb_print",
132 lua.create_function({
133 let stdout = stdout.clone();
134 move |_, args: MultiValue| Self::print(args, &stdout)
135 })?,
136 )?;
137 globals.set(
138 "search",
139 lua.create_async_function({
140 let foreground_fns_tx = foreground_fns_tx.clone();
141 move |lua, regex| {
142 let mut foreground_fns_tx = foreground_fns_tx.clone();
143 let fs = fs.clone();
144 async move {
145 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
146 .await
147 .into_lua_err()
148 }
149 }
150 })?,
151 )?;
152 globals.set(
153 "outline",
154 lua.create_async_function({
155 let root_dir = root_dir.clone();
156 move |_lua, path| {
157 let mut foreground_fns_tx = foreground_fns_tx.clone();
158 let root_dir = root_dir.clone();
159 async move {
160 Self::outline(root_dir, &mut foreground_fns_tx, path)
161 .await
162 .into_lua_err()
163 }
164 }
165 })?,
166 )?;
167 globals.set(
168 "sb_io_open",
169 lua.create_function({
170 let fs_changes = fs_changes.clone();
171 let root_dir = root_dir.clone();
172 move |lua, (path_str, mode)| {
173 Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
174 }
175 })?,
176 )?;
177 globals.set("user_script", script)?;
178
179 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
180
181 // Drop Lua instance to decrement reference count.
182 drop(lua);
183
184 anyhow::Ok(())
185 }
186 });
187
188 task
189 }
190
191 pub fn get(&self, script_id: ScriptId) -> &Script {
192 &self.scripts[script_id.0 as usize]
193 }
194
195 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
196 &mut self.scripts[script_id.0 as usize]
197 }
198
199 /// Sandboxed print() function in Lua.
200 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
201 for (index, arg) in args.into_iter().enumerate() {
202 // Lua's `print()` prints tab characters between each argument.
203 if index > 0 {
204 stdout.lock().push('\t');
205 }
206
207 // If the argument's to_string() fails, have the whole function call fail.
208 stdout.lock().push_str(&arg.to_string()?);
209 }
210 stdout.lock().push('\n');
211
212 Ok(())
213 }
214
215 /// Sandboxed io.open() function in Lua.
216 fn io_open(
217 lua: &Lua,
218 fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
219 root_dir: Option<&Arc<Path>>,
220 path_str: String,
221 mode: Option<String>,
222 ) -> mlua::Result<(Option<Table>, String)> {
223 let root_dir = root_dir
224 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
225
226 let mode = mode.unwrap_or_else(|| "r".to_string());
227
228 // Parse the mode string to determine read/write permissions
229 let read_perm = mode.contains('r');
230 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
231 let append = mode.contains('a');
232 let truncate = mode.contains('w');
233
234 // This will be the Lua value returned from the `open` function.
235 let file = lua.create_table()?;
236
237 // Store file metadata in the file
238 file.set("__path", path_str.clone())?;
239 file.set("__mode", mode.clone())?;
240 file.set("__read_perm", read_perm)?;
241 file.set("__write_perm", write_perm)?;
242
243 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
244 Ok(path) => path,
245 Err(err) => return Ok((None, format!("{err}"))),
246 };
247
248 // close method
249 let close_fn = {
250 let fs_changes = fs_changes.clone();
251 lua.create_function(move |_lua, file_userdata: mlua::Table| {
252 let write_perm = file_userdata.get::<bool>("__write_perm")?;
253 let path = file_userdata.get::<String>("__path")?;
254
255 if write_perm {
256 // When closing a writable file, record the content
257 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
258 let content_ref = content.borrow::<FileContent>()?;
259 let content_vec = content_ref.0.borrow();
260
261 // Don't actually write to disk; instead, just update fs_changes.
262 let path_buf = PathBuf::from(&path);
263 fs_changes
264 .lock()
265 .insert(path_buf.clone(), content_vec.clone());
266 }
267
268 Ok(true)
269 })?
270 };
271 file.set("close", close_fn)?;
272
273 // If it's a directory, give it a custom read() and return early.
274 if path.is_dir() {
275 // TODO handle the case where we changed it in the in-memory fs
276
277 // Create a special directory handle
278 file.set("__is_directory", true)?;
279
280 // Store directory entries
281 let entries = match std::fs::read_dir(&path) {
282 Ok(entries) => {
283 let mut entry_names = Vec::new();
284 for entry in entries.flatten() {
285 entry_names.push(entry.file_name().to_string_lossy().into_owned());
286 }
287 entry_names
288 }
289 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
290 };
291
292 // Save the list of entries
293 file.set("__dir_entries", entries)?;
294 file.set("__dir_position", 0usize)?;
295
296 // Create a directory-specific read function
297 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
298 let position = file_userdata.get::<usize>("__dir_position")?;
299 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
300
301 if position >= entries.len() {
302 return Ok(None); // No more entries
303 }
304
305 let entry = entries[position].clone();
306 file_userdata.set("__dir_position", position + 1)?;
307
308 Ok(Some(entry))
309 })?;
310 file.set("read", read_fn)?;
311
312 // If we got this far, the directory was opened successfully
313 return Ok((Some(file), String::new()));
314 }
315
316 let fs_changes_map = fs_changes.lock();
317
318 let is_in_changes = fs_changes_map.contains_key(&path);
319 let file_exists = is_in_changes || path.exists();
320 let mut file_content = Vec::new();
321
322 if file_exists && !truncate {
323 if is_in_changes {
324 file_content = fs_changes_map.get(&path).unwrap().clone();
325 } else {
326 // Try to read existing content if file exists and we're not truncating
327 match std::fs::read(&path) {
328 Ok(content) => file_content = content,
329 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
330 }
331 }
332 }
333
334 drop(fs_changes_map); // Unlock the fs_changes mutex.
335
336 // If in append mode, position should be at the end
337 let position = if append && file_exists {
338 file_content.len()
339 } else {
340 0
341 };
342 file.set("__position", position)?;
343 file.set(
344 "__content",
345 lua.create_userdata(FileContent(RefCell::new(file_content)))?,
346 )?;
347
348 // Create file methods
349
350 // read method
351 let read_fn = {
352 lua.create_function(
353 |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
354 let read_perm = file_userdata.get::<bool>("__read_perm")?;
355 if !read_perm {
356 return Err(mlua::Error::runtime("File not open for reading"));
357 }
358
359 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
360 let mut position = file_userdata.get::<usize>("__position")?;
361 let content_ref = content.borrow::<FileContent>()?;
362 let content_vec = content_ref.0.borrow();
363
364 if position >= content_vec.len() {
365 return Ok(None); // EOF
366 }
367
368 match format {
369 Some(mlua::Value::String(s)) => {
370 let lossy_string = s.to_string_lossy();
371 let format_str: &str = lossy_string.as_ref();
372
373 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
374 match &format_str[0..2] {
375 "*a" => {
376 // Read entire file from current position
377 let result = String::from_utf8_lossy(&content_vec[position..])
378 .to_string();
379 position = content_vec.len();
380 file_userdata.set("__position", position)?;
381 Ok(Some(result))
382 }
383 "*l" => {
384 // Read next line
385 let mut line = Vec::new();
386 let mut found_newline = false;
387
388 while position < content_vec.len() {
389 let byte = content_vec[position];
390 position += 1;
391
392 if byte == b'\n' {
393 found_newline = true;
394 break;
395 }
396
397 // Skip \r in \r\n sequence but add it if it's alone
398 if byte == b'\r' {
399 if position < content_vec.len()
400 && content_vec[position] == b'\n'
401 {
402 position += 1;
403 found_newline = true;
404 break;
405 }
406 }
407
408 line.push(byte);
409 }
410
411 file_userdata.set("__position", position)?;
412
413 if !found_newline
414 && line.is_empty()
415 && position >= content_vec.len()
416 {
417 return Ok(None); // EOF
418 }
419
420 let result = String::from_utf8_lossy(&line).to_string();
421 Ok(Some(result))
422 }
423 "*n" => {
424 // Try to parse as a number (number of bytes to read)
425 match format_str.parse::<usize>() {
426 Ok(n) => {
427 let end =
428 std::cmp::min(position + n, content_vec.len());
429 let bytes = &content_vec[position..end];
430 let result = String::from_utf8_lossy(bytes).to_string();
431 position = end;
432 file_userdata.set("__position", position)?;
433 Ok(Some(result))
434 }
435 Err(_) => Err(mlua::Error::runtime(format!(
436 "Invalid format: {}",
437 format_str
438 ))),
439 }
440 }
441 "*L" => {
442 // Read next line keeping the end of line
443 let mut line = Vec::new();
444
445 while position < content_vec.len() {
446 let byte = content_vec[position];
447 position += 1;
448
449 line.push(byte);
450
451 if byte == b'\n' {
452 break;
453 }
454
455 // If we encounter a \r, add it and check if the next is \n
456 if byte == b'\r'
457 && position < content_vec.len()
458 && content_vec[position] == b'\n'
459 {
460 line.push(content_vec[position]);
461 position += 1;
462 break;
463 }
464 }
465
466 file_userdata.set("__position", position)?;
467
468 if line.is_empty() && position >= content_vec.len() {
469 return Ok(None); // EOF
470 }
471
472 let result = String::from_utf8_lossy(&line).to_string();
473 Ok(Some(result))
474 }
475 _ => Err(mlua::Error::runtime(format!(
476 "Unsupported format: {}",
477 format_str
478 ))),
479 }
480 }
481 Some(mlua::Value::Number(n)) => {
482 // Read n bytes
483 let n = n as usize;
484 let end = std::cmp::min(position + n, content_vec.len());
485 let bytes = &content_vec[position..end];
486 let result = String::from_utf8_lossy(bytes).to_string();
487 position = end;
488 file_userdata.set("__position", position)?;
489 Ok(Some(result))
490 }
491 Some(_) => Err(mlua::Error::runtime("Invalid format")),
492 None => {
493 // Default is to read a line
494 let mut line = Vec::new();
495 let mut found_newline = false;
496
497 while position < content_vec.len() {
498 let byte = content_vec[position];
499 position += 1;
500
501 if byte == b'\n' {
502 found_newline = true;
503 break;
504 }
505
506 // Handle \r\n
507 if byte == b'\r' {
508 if position < content_vec.len()
509 && content_vec[position] == b'\n'
510 {
511 position += 1;
512 found_newline = true;
513 break;
514 }
515 }
516
517 line.push(byte);
518 }
519
520 file_userdata.set("__position", position)?;
521
522 if !found_newline && line.is_empty() && position >= content_vec.len() {
523 return Ok(None); // EOF
524 }
525
526 let result = String::from_utf8_lossy(&line).to_string();
527 Ok(Some(result))
528 }
529 }
530 },
531 )?
532 };
533 file.set("read", read_fn)?;
534
535 // write method
536 let write_fn = {
537 let fs_changes = fs_changes.clone();
538
539 lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
540 let write_perm = file_userdata.get::<bool>("__write_perm")?;
541 if !write_perm {
542 return Err(mlua::Error::runtime("File not open for writing"));
543 }
544
545 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
546 let position = file_userdata.get::<usize>("__position")?;
547 let content_ref = content.borrow::<FileContent>()?;
548 let mut content_vec = content_ref.0.borrow_mut();
549
550 let bytes = text.as_bytes();
551
552 // Ensure the vector has enough capacity
553 if position + bytes.len() > content_vec.len() {
554 content_vec.resize(position + bytes.len(), 0);
555 }
556
557 // Write the bytes
558 for (i, &byte) in bytes.iter().enumerate() {
559 content_vec[position + i] = byte;
560 }
561
562 // Update position
563 let new_position = position + bytes.len();
564 file_userdata.set("__position", new_position)?;
565
566 // Update fs_changes
567 let path = file_userdata.get::<String>("__path")?;
568 let path_buf = PathBuf::from(path);
569 fs_changes.lock().insert(path_buf, content_vec.clone());
570
571 Ok(true)
572 })?
573 };
574 file.set("write", write_fn)?;
575
576 // If we got this far, the file was opened successfully
577 Ok((Some(file), String::new()))
578 }
579
580 async fn search(
581 lua: &Lua,
582 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
583 fs: Arc<dyn Fs>,
584 regex: String,
585 ) -> anyhow::Result<Table> {
586 // TODO: Allow specification of these options.
587 let search_query = SearchQuery::regex(
588 ®ex,
589 false,
590 false,
591 false,
592 PathMatcher::default(),
593 PathMatcher::default(),
594 None,
595 );
596 let search_query = match search_query {
597 Ok(query) => query,
598 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
599 };
600
601 // TODO: Should use `search_query.regex`. The tool description should also be updated,
602 // as it specifies standard regex.
603 let search_regex = match Regex::new(®ex) {
604 Ok(re) => re,
605 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
606 };
607
608 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
609
610 let mut search_results: Vec<Table> = Vec::new();
611 while let Some(path) = abs_paths_rx.next().await {
612 // Skip files larger than 1MB
613 if let Ok(Some(metadata)) = fs.metadata(&path).await {
614 if metadata.len > 1_000_000 {
615 continue;
616 }
617 }
618
619 // Attempt to read the file as text
620 if let Ok(content) = fs.load(&path).await {
621 let mut matches = Vec::new();
622
623 // Find all regex matches in the content
624 for capture in search_regex.find_iter(&content) {
625 matches.push(capture.as_str().to_string());
626 }
627
628 // If we found matches, create a result entry
629 if !matches.is_empty() {
630 let result_entry = lua.create_table()?;
631 result_entry.set("path", path.to_string_lossy().to_string())?;
632
633 let matches_table = lua.create_table()?;
634 for (ix, m) in matches.iter().enumerate() {
635 matches_table.set(ix + 1, m.clone())?;
636 }
637 result_entry.set("matches", matches_table)?;
638
639 search_results.push(result_entry);
640 }
641 }
642 }
643
644 // Create a table to hold our results
645 let results_table = lua.create_table()?;
646 for (ix, entry) in search_results.into_iter().enumerate() {
647 results_table.set(ix + 1, entry)?;
648 }
649
650 Ok(results_table)
651 }
652
653 async fn find_search_candidates(
654 search_query: SearchQuery,
655 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
656 ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
657 Self::run_foreground_fn(
658 "finding search file candidates",
659 foreground_tx,
660 Box::new(move |session, mut cx| {
661 session.update(&mut cx, |session, cx| {
662 session.project.update(cx, |project, cx| {
663 project.worktree_store().update(cx, |worktree_store, cx| {
664 // TODO: Better limit? For now this is the same as
665 // MAX_SEARCH_RESULT_FILES.
666 let limit = 5000;
667 // TODO: Providing non-empty open_entries can make this a bit more
668 // efficient as it can skip checking that these paths are textual.
669 let open_entries = HashSet::default();
670 let candidates = worktree_store.find_search_candidates(
671 search_query,
672 limit,
673 open_entries,
674 project.fs().clone(),
675 cx,
676 );
677 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
678 cx.spawn(|worktree_store, cx| async move {
679 pin_mut!(candidates);
680
681 while let Some(project_path) = candidates.next().await {
682 worktree_store.read_with(&cx, |worktree_store, cx| {
683 if let Some(worktree) = worktree_store
684 .worktree_for_id(project_path.worktree_id, cx)
685 {
686 if let Some(abs_path) = worktree
687 .read(cx)
688 .absolutize(&project_path.path)
689 .log_err()
690 {
691 abs_paths_tx.unbounded_send(abs_path)?;
692 }
693 }
694 anyhow::Ok(())
695 })??;
696 }
697 anyhow::Ok(())
698 })
699 .detach();
700 abs_paths_rx
701 })
702 })
703 })
704 }),
705 )
706 .await?
707 }
708
709 async fn outline(
710 root_dir: Option<Arc<Path>>,
711 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
712 path_str: String,
713 ) -> anyhow::Result<String> {
714 let root_dir = root_dir
715 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
716 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
717 let outline = Self::run_foreground_fn(
718 "getting code outline",
719 foreground_tx,
720 Box::new(move |session, cx| {
721 cx.spawn(move |mut cx| async move {
722 // TODO: This will not use file content from `fs_changes`. It will also reflect
723 // user changes that have not been saved.
724 let buffer = session
725 .update(&mut cx, |session, cx| {
726 session
727 .project
728 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
729 })?
730 .await?;
731 buffer.update(&mut cx, |buffer, _cx| {
732 if let Some(outline) = buffer.snapshot().outline(None) {
733 Ok(outline)
734 } else {
735 Err(anyhow!("No outline for file {path_str}"))
736 }
737 })
738 })
739 }),
740 )
741 .await?
742 .await??;
743
744 Ok(outline
745 .items
746 .into_iter()
747 .map(|item| {
748 if item.text.contains('\n') {
749 log::error!("Outline item unexpectedly contains newline");
750 }
751 format!("{}{}", " ".repeat(item.depth), item.text)
752 })
753 .collect::<Vec<String>>()
754 .join("\n"))
755 }
756
757 async fn run_foreground_fn<R: Send + 'static>(
758 description: &str,
759 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
760 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
761 ) -> anyhow::Result<R> {
762 let (response_tx, response_rx) = oneshot::channel();
763 let send_result = foreground_tx
764 .send(ForegroundFn(Box::new(move |this, cx| {
765 response_tx.send(function(this, cx)).ok();
766 })))
767 .await;
768 match send_result {
769 Ok(()) => (),
770 Err(err) => {
771 return Err(anyhow::Error::new(err).context(format!(
772 "Internal error while enqueuing work for {description}"
773 )));
774 }
775 }
776 match response_rx.await {
777 Ok(result) => Ok(result),
778 Err(oneshot::Canceled) => Err(anyhow!(
779 "Internal error: response oneshot was canceled while {description}."
780 )),
781 }
782 }
783
784 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
785 let path = Path::new(&path_str);
786 if path.is_absolute() {
787 // Check if path starts with root_dir prefix without resolving symlinks
788 if path.starts_with(&root_dir) {
789 Ok(path.to_path_buf())
790 } else {
791 Err(anyhow!(
792 "Error: Absolute path {} is outside the current working directory",
793 path_str
794 ))
795 }
796 } else {
797 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
798 Ok(root_dir.join(path))
799 }
800 }
801}
802
/// Byte content of a sandboxed file handle, stored as Lua userdata under the handle's
/// `__content` key and mutated in place by the handle's `write` function.
struct FileContent(RefCell<Vec<u8>>);
804
805impl UserData for FileContent {
806 fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
807 // FileContent doesn't have any methods so far.
808 }
809}
810
811#[derive(Debug)]
812pub enum ScriptEvent {
813 Spawned(ScriptId),
814 Exited(ScriptId),
815}
816
817impl EventEmitter<ScriptEvent> for ScriptSession {}
818
/// Identifies a script within its [`ScriptSession`]; the wrapped value is the script's
/// index in the session's script list.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
821
822pub struct Script {
823 pub id: ScriptId,
824 pub state: ScriptState,
825 pub source: SharedString,
826}
827
828pub enum ScriptState {
829 Generating,
830 Running {
831 stdout: Arc<Mutex<String>>,
832 },
833 Succeeded {
834 stdout: String,
835 },
836 Failed {
837 stdout: String,
838 error: anyhow::Error,
839 },
840}
841
842impl Script {
843 pub fn source_tag(&self) -> String {
844 format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
845 }
846
847 /// If exited, returns a message with the output for the LLM
848 pub fn output_message_for_llm(&self) -> Option<String> {
849 match &self.state {
850 ScriptState::Generating { .. } => None,
851 ScriptState::Running { .. } => None,
852 ScriptState::Succeeded { stdout } => {
853 format!("Here's the script output:\n{}", stdout).into()
854 }
855 ScriptState::Failed { stdout, error } => format!(
856 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
857 error, stdout
858 )
859 .into(),
860 }
861 }
862
863 /// Get a snapshot of the script's stdout
864 pub fn stdout_snapshot(&self) -> String {
865 match &self.state {
866 ScriptState::Generating { .. } => String::new(),
867 ScriptState::Running { stdout } => stdout.lock().clone(),
868 ScriptState::Succeeded { stdout } => stdout.clone(),
869 ScriptState::Failed { stdout, .. } => stdout.clone(),
870 }
871 }
872
873 /// Returns the error if the script failed, otherwise None
874 pub fn error(&self) -> Option<&anyhow::Error> {
875 match &self.state {
876 ScriptState::Generating { .. } => None,
877 ScriptState::Running { .. } => None,
878 ScriptState::Succeeded { .. } => None,
879 ScriptState::Failed { error, .. } => Some(error),
880 }
881 }
882}
883
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    /// `print` joins its arguments with tabs and ends each call with a newline.
    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let source = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let stdout = test_script(source, cx).await.unwrap();
        assert_eq!(stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    /// `search` reports the one fixture file that contains the pattern, with its match.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let source = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print(" " .. match)
                end
            end
        "#;

        let stdout = test_script(source, cx).await.unwrap();
        assert_eq!(stdout, "File: /file1.txt\nMatches:\n world\n");
    }

    /// Runs `source` in a session over a fake two-file project rooted at `/` and returns
    /// everything the script printed.
    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let fixture = json!({
            "file1.txt": "Hello world!",
            "file2.txt": "Goodbye moon!"
        });
        fs.insert_tree("/", fixture).await;

        let project = Project::test(fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptSession::new(project, cx));

        let (script_id, run_task) = session.update(cx, |session, cx| {
            let script_id = session.new_script();
            (
                script_id,
                session.run_script(script_id, source.to_string(), cx),
            )
        });

        run_task.await?;

        let stdout = session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot());
        Ok(stdout)
    }

    /// Installs the test settings store and the project settings.
    fn init_test(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(Project::init_settings);
    }
}