1use anyhow::anyhow;
2use collections::{HashMap, HashSet};
3use futures::{
4 channel::{mpsc, oneshot},
5 pin_mut, SinkExt, StreamExt,
6};
7use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
8use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
9use parking_lot::Mutex;
10use project::{search::SearchQuery, Fs, Project};
11use regex::Regex;
12use std::{
13 cell::RefCell,
14 path::{Path, PathBuf},
15 sync::Arc,
16};
17use util::{paths::PathMatcher, ResultExt};
18
19struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
20
21pub struct ScriptSession {
22 project: Entity<Project>,
23 // TODO Remove this
24 fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
25 foreground_fns_tx: mpsc::Sender<ForegroundFn>,
26 _invoke_foreground_fns: Task<()>,
27 scripts: Vec<Script>,
28}
29
30impl ScriptSession {
31 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
32 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
33 ScriptSession {
34 project,
35 fs_changes: Arc::new(Mutex::new(HashMap::default())),
36 foreground_fns_tx,
37 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
38 while let Some(foreground_fn) = foreground_fns_rx.next().await {
39 foreground_fn.0(this.clone(), cx.clone());
40 }
41 }),
42 scripts: Vec::new(),
43 }
44 }
45
46 pub fn run_script(
47 &mut self,
48 script_src: String,
49 cx: &mut Context<Self>,
50 ) -> (ScriptId, Task<()>) {
51 let id = ScriptId(self.scripts.len() as u32);
52
53 let stdout = Arc::new(Mutex::new(String::new()));
54
55 let script = Script {
56 state: ScriptState::Running {
57 stdout: stdout.clone(),
58 },
59 };
60 self.scripts.push(script);
61
62 let task = self.run_lua(script_src, stdout, cx);
63
64 let task = cx.spawn(|session, mut cx| async move {
65 let result = task.await;
66
67 session
68 .update(&mut cx, |session, _cx| {
69 let script = session.get_mut(id);
70 let stdout = script.stdout_snapshot();
71
72 script.state = match result {
73 Ok(()) => ScriptState::Succeeded { stdout },
74 Err(error) => ScriptState::Failed { stdout, error },
75 };
76 })
77 .log_err();
78 });
79
80 (id, task)
81 }
82
83 fn run_lua(
84 &mut self,
85 script: String,
86 stdout: Arc<Mutex<String>>,
87 cx: &mut Context<Self>,
88 ) -> Task<anyhow::Result<()>> {
89 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
90
91 // TODO Remove fs_changes
92 let fs_changes = self.fs_changes.clone();
93 // TODO Honor all worktrees instead of the first one
94 let root_dir = self
95 .project
96 .read(cx)
97 .visible_worktrees(cx)
98 .next()
99 .map(|worktree| worktree.read(cx).abs_path());
100
101 let fs = self.project.read(cx).fs().clone();
102 let foreground_fns_tx = self.foreground_fns_tx.clone();
103
104 let task = cx.background_spawn({
105 let stdout = stdout.clone();
106
107 async move {
108 let lua = Lua::new();
109 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
110 let globals = lua.globals();
111
112 // Use the project root dir as the script's current working dir.
113 if let Some(root_dir) = &root_dir {
114 if let Some(root_dir) = root_dir.to_str() {
115 globals.set("cwd", root_dir)?;
116 }
117 }
118
119 globals.set(
120 "sb_print",
121 lua.create_function({
122 let stdout = stdout.clone();
123 move |_, args: MultiValue| Self::print(args, &stdout)
124 })?,
125 )?;
126 globals.set(
127 "search",
128 lua.create_async_function({
129 let foreground_fns_tx = foreground_fns_tx.clone();
130 move |lua, regex| {
131 let mut foreground_fns_tx = foreground_fns_tx.clone();
132 let fs = fs.clone();
133 async move {
134 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
135 .await
136 .into_lua_err()
137 }
138 }
139 })?,
140 )?;
141 globals.set(
142 "outline",
143 lua.create_async_function({
144 let root_dir = root_dir.clone();
145 move |_lua, path| {
146 let mut foreground_fns_tx = foreground_fns_tx.clone();
147 let root_dir = root_dir.clone();
148 async move {
149 Self::outline(root_dir, &mut foreground_fns_tx, path)
150 .await
151 .into_lua_err()
152 }
153 }
154 })?,
155 )?;
156 globals.set(
157 "sb_io_open",
158 lua.create_function({
159 let fs_changes = fs_changes.clone();
160 let root_dir = root_dir.clone();
161 move |lua, (path_str, mode)| {
162 Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
163 }
164 })?,
165 )?;
166 globals.set("user_script", script)?;
167
168 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
169
170 // Drop Lua instance to decrement reference count.
171 drop(lua);
172
173 anyhow::Ok(())
174 }
175 });
176
177 task
178 }
179
180 pub fn get(&self, script_id: ScriptId) -> &Script {
181 &self.scripts[script_id.0 as usize]
182 }
183
184 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
185 &mut self.scripts[script_id.0 as usize]
186 }
187
188 /// Sandboxed print() function in Lua.
189 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
190 for (index, arg) in args.into_iter().enumerate() {
191 // Lua's `print()` prints tab characters between each argument.
192 if index > 0 {
193 stdout.lock().push('\t');
194 }
195
196 // If the argument's to_string() fails, have the whole function call fail.
197 stdout.lock().push_str(&arg.to_string()?);
198 }
199 stdout.lock().push('\n');
200
201 Ok(())
202 }
203
204 /// Sandboxed io.open() function in Lua.
205 fn io_open(
206 lua: &Lua,
207 fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
208 root_dir: Option<&Arc<Path>>,
209 path_str: String,
210 mode: Option<String>,
211 ) -> mlua::Result<(Option<Table>, String)> {
212 let root_dir = root_dir
213 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
214
215 let mode = mode.unwrap_or_else(|| "r".to_string());
216
217 // Parse the mode string to determine read/write permissions
218 let read_perm = mode.contains('r');
219 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
220 let append = mode.contains('a');
221 let truncate = mode.contains('w');
222
223 // This will be the Lua value returned from the `open` function.
224 let file = lua.create_table()?;
225
226 // Store file metadata in the file
227 file.set("__path", path_str.clone())?;
228 file.set("__mode", mode.clone())?;
229 file.set("__read_perm", read_perm)?;
230 file.set("__write_perm", write_perm)?;
231
232 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
233 Ok(path) => path,
234 Err(err) => return Ok((None, format!("{err}"))),
235 };
236
237 // close method
238 let close_fn = {
239 let fs_changes = fs_changes.clone();
240 lua.create_function(move |_lua, file_userdata: mlua::Table| {
241 let write_perm = file_userdata.get::<bool>("__write_perm")?;
242 let path = file_userdata.get::<String>("__path")?;
243
244 if write_perm {
245 // When closing a writable file, record the content
246 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
247 let content_ref = content.borrow::<FileContent>()?;
248 let content_vec = content_ref.0.borrow();
249
250 // Don't actually write to disk; instead, just update fs_changes.
251 let path_buf = PathBuf::from(&path);
252 fs_changes
253 .lock()
254 .insert(path_buf.clone(), content_vec.clone());
255 }
256
257 Ok(true)
258 })?
259 };
260 file.set("close", close_fn)?;
261
262 // If it's a directory, give it a custom read() and return early.
263 if path.is_dir() {
264 // TODO handle the case where we changed it in the in-memory fs
265
266 // Create a special directory handle
267 file.set("__is_directory", true)?;
268
269 // Store directory entries
270 let entries = match std::fs::read_dir(&path) {
271 Ok(entries) => {
272 let mut entry_names = Vec::new();
273 for entry in entries.flatten() {
274 entry_names.push(entry.file_name().to_string_lossy().into_owned());
275 }
276 entry_names
277 }
278 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
279 };
280
281 // Save the list of entries
282 file.set("__dir_entries", entries)?;
283 file.set("__dir_position", 0usize)?;
284
285 // Create a directory-specific read function
286 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
287 let position = file_userdata.get::<usize>("__dir_position")?;
288 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
289
290 if position >= entries.len() {
291 return Ok(None); // No more entries
292 }
293
294 let entry = entries[position].clone();
295 file_userdata.set("__dir_position", position + 1)?;
296
297 Ok(Some(entry))
298 })?;
299 file.set("read", read_fn)?;
300
301 // If we got this far, the directory was opened successfully
302 return Ok((Some(file), String::new()));
303 }
304
305 let fs_changes_map = fs_changes.lock();
306
307 let is_in_changes = fs_changes_map.contains_key(&path);
308 let file_exists = is_in_changes || path.exists();
309 let mut file_content = Vec::new();
310
311 if file_exists && !truncate {
312 if is_in_changes {
313 file_content = fs_changes_map.get(&path).unwrap().clone();
314 } else {
315 // Try to read existing content if file exists and we're not truncating
316 match std::fs::read(&path) {
317 Ok(content) => file_content = content,
318 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
319 }
320 }
321 }
322
323 drop(fs_changes_map); // Unlock the fs_changes mutex.
324
325 // If in append mode, position should be at the end
326 let position = if append && file_exists {
327 file_content.len()
328 } else {
329 0
330 };
331 file.set("__position", position)?;
332 file.set(
333 "__content",
334 lua.create_userdata(FileContent(RefCell::new(file_content)))?,
335 )?;
336
337 // Create file methods
338
339 // read method
340 let read_fn = {
341 lua.create_function(
342 |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
343 let read_perm = file_userdata.get::<bool>("__read_perm")?;
344 if !read_perm {
345 return Err(mlua::Error::runtime("File not open for reading"));
346 }
347
348 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
349 let mut position = file_userdata.get::<usize>("__position")?;
350 let content_ref = content.borrow::<FileContent>()?;
351 let content_vec = content_ref.0.borrow();
352
353 if position >= content_vec.len() {
354 return Ok(None); // EOF
355 }
356
357 match format {
358 Some(mlua::Value::String(s)) => {
359 let lossy_string = s.to_string_lossy();
360 let format_str: &str = lossy_string.as_ref();
361
362 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
363 match &format_str[0..2] {
364 "*a" => {
365 // Read entire file from current position
366 let result = String::from_utf8_lossy(&content_vec[position..])
367 .to_string();
368 position = content_vec.len();
369 file_userdata.set("__position", position)?;
370 Ok(Some(result))
371 }
372 "*l" => {
373 // Read next line
374 let mut line = Vec::new();
375 let mut found_newline = false;
376
377 while position < content_vec.len() {
378 let byte = content_vec[position];
379 position += 1;
380
381 if byte == b'\n' {
382 found_newline = true;
383 break;
384 }
385
386 // Skip \r in \r\n sequence but add it if it's alone
387 if byte == b'\r' {
388 if position < content_vec.len()
389 && content_vec[position] == b'\n'
390 {
391 position += 1;
392 found_newline = true;
393 break;
394 }
395 }
396
397 line.push(byte);
398 }
399
400 file_userdata.set("__position", position)?;
401
402 if !found_newline
403 && line.is_empty()
404 && position >= content_vec.len()
405 {
406 return Ok(None); // EOF
407 }
408
409 let result = String::from_utf8_lossy(&line).to_string();
410 Ok(Some(result))
411 }
412 "*n" => {
413 // Try to parse as a number (number of bytes to read)
414 match format_str.parse::<usize>() {
415 Ok(n) => {
416 let end =
417 std::cmp::min(position + n, content_vec.len());
418 let bytes = &content_vec[position..end];
419 let result = String::from_utf8_lossy(bytes).to_string();
420 position = end;
421 file_userdata.set("__position", position)?;
422 Ok(Some(result))
423 }
424 Err(_) => Err(mlua::Error::runtime(format!(
425 "Invalid format: {}",
426 format_str
427 ))),
428 }
429 }
430 "*L" => {
431 // Read next line keeping the end of line
432 let mut line = Vec::new();
433
434 while position < content_vec.len() {
435 let byte = content_vec[position];
436 position += 1;
437
438 line.push(byte);
439
440 if byte == b'\n' {
441 break;
442 }
443
444 // If we encounter a \r, add it and check if the next is \n
445 if byte == b'\r'
446 && position < content_vec.len()
447 && content_vec[position] == b'\n'
448 {
449 line.push(content_vec[position]);
450 position += 1;
451 break;
452 }
453 }
454
455 file_userdata.set("__position", position)?;
456
457 if line.is_empty() && position >= content_vec.len() {
458 return Ok(None); // EOF
459 }
460
461 let result = String::from_utf8_lossy(&line).to_string();
462 Ok(Some(result))
463 }
464 _ => Err(mlua::Error::runtime(format!(
465 "Unsupported format: {}",
466 format_str
467 ))),
468 }
469 }
470 Some(mlua::Value::Number(n)) => {
471 // Read n bytes
472 let n = n as usize;
473 let end = std::cmp::min(position + n, content_vec.len());
474 let bytes = &content_vec[position..end];
475 let result = String::from_utf8_lossy(bytes).to_string();
476 position = end;
477 file_userdata.set("__position", position)?;
478 Ok(Some(result))
479 }
480 Some(_) => Err(mlua::Error::runtime("Invalid format")),
481 None => {
482 // Default is to read a line
483 let mut line = Vec::new();
484 let mut found_newline = false;
485
486 while position < content_vec.len() {
487 let byte = content_vec[position];
488 position += 1;
489
490 if byte == b'\n' {
491 found_newline = true;
492 break;
493 }
494
495 // Handle \r\n
496 if byte == b'\r' {
497 if position < content_vec.len()
498 && content_vec[position] == b'\n'
499 {
500 position += 1;
501 found_newline = true;
502 break;
503 }
504 }
505
506 line.push(byte);
507 }
508
509 file_userdata.set("__position", position)?;
510
511 if !found_newline && line.is_empty() && position >= content_vec.len() {
512 return Ok(None); // EOF
513 }
514
515 let result = String::from_utf8_lossy(&line).to_string();
516 Ok(Some(result))
517 }
518 }
519 },
520 )?
521 };
522 file.set("read", read_fn)?;
523
524 // write method
525 let write_fn = {
526 let fs_changes = fs_changes.clone();
527
528 lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
529 let write_perm = file_userdata.get::<bool>("__write_perm")?;
530 if !write_perm {
531 return Err(mlua::Error::runtime("File not open for writing"));
532 }
533
534 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
535 let position = file_userdata.get::<usize>("__position")?;
536 let content_ref = content.borrow::<FileContent>()?;
537 let mut content_vec = content_ref.0.borrow_mut();
538
539 let bytes = text.as_bytes();
540
541 // Ensure the vector has enough capacity
542 if position + bytes.len() > content_vec.len() {
543 content_vec.resize(position + bytes.len(), 0);
544 }
545
546 // Write the bytes
547 for (i, &byte) in bytes.iter().enumerate() {
548 content_vec[position + i] = byte;
549 }
550
551 // Update position
552 let new_position = position + bytes.len();
553 file_userdata.set("__position", new_position)?;
554
555 // Update fs_changes
556 let path = file_userdata.get::<String>("__path")?;
557 let path_buf = PathBuf::from(path);
558 fs_changes.lock().insert(path_buf, content_vec.clone());
559
560 Ok(true)
561 })?
562 };
563 file.set("write", write_fn)?;
564
565 // If we got this far, the file was opened successfully
566 Ok((Some(file), String::new()))
567 }
568
569 async fn search(
570 lua: &Lua,
571 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
572 fs: Arc<dyn Fs>,
573 regex: String,
574 ) -> anyhow::Result<Table> {
575 // TODO: Allow specification of these options.
576 let search_query = SearchQuery::regex(
577 ®ex,
578 false,
579 false,
580 false,
581 PathMatcher::default(),
582 PathMatcher::default(),
583 None,
584 );
585 let search_query = match search_query {
586 Ok(query) => query,
587 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
588 };
589
590 // TODO: Should use `search_query.regex`. The tool description should also be updated,
591 // as it specifies standard regex.
592 let search_regex = match Regex::new(®ex) {
593 Ok(re) => re,
594 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
595 };
596
597 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
598
599 let mut search_results: Vec<Table> = Vec::new();
600 while let Some(path) = abs_paths_rx.next().await {
601 // Skip files larger than 1MB
602 if let Ok(Some(metadata)) = fs.metadata(&path).await {
603 if metadata.len > 1_000_000 {
604 continue;
605 }
606 }
607
608 // Attempt to read the file as text
609 if let Ok(content) = fs.load(&path).await {
610 let mut matches = Vec::new();
611
612 // Find all regex matches in the content
613 for capture in search_regex.find_iter(&content) {
614 matches.push(capture.as_str().to_string());
615 }
616
617 // If we found matches, create a result entry
618 if !matches.is_empty() {
619 let result_entry = lua.create_table()?;
620 result_entry.set("path", path.to_string_lossy().to_string())?;
621
622 let matches_table = lua.create_table()?;
623 for (ix, m) in matches.iter().enumerate() {
624 matches_table.set(ix + 1, m.clone())?;
625 }
626 result_entry.set("matches", matches_table)?;
627
628 search_results.push(result_entry);
629 }
630 }
631 }
632
633 // Create a table to hold our results
634 let results_table = lua.create_table()?;
635 for (ix, entry) in search_results.into_iter().enumerate() {
636 results_table.set(ix + 1, entry)?;
637 }
638
639 Ok(results_table)
640 }
641
642 async fn find_search_candidates(
643 search_query: SearchQuery,
644 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
645 ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
646 Self::run_foreground_fn(
647 "finding search file candidates",
648 foreground_tx,
649 Box::new(move |session, mut cx| {
650 session.update(&mut cx, |session, cx| {
651 session.project.update(cx, |project, cx| {
652 project.worktree_store().update(cx, |worktree_store, cx| {
653 // TODO: Better limit? For now this is the same as
654 // MAX_SEARCH_RESULT_FILES.
655 let limit = 5000;
656 // TODO: Providing non-empty open_entries can make this a bit more
657 // efficient as it can skip checking that these paths are textual.
658 let open_entries = HashSet::default();
659 let candidates = worktree_store.find_search_candidates(
660 search_query,
661 limit,
662 open_entries,
663 project.fs().clone(),
664 cx,
665 );
666 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
667 cx.spawn(|worktree_store, cx| async move {
668 pin_mut!(candidates);
669
670 while let Some(project_path) = candidates.next().await {
671 worktree_store.read_with(&cx, |worktree_store, cx| {
672 if let Some(worktree) = worktree_store
673 .worktree_for_id(project_path.worktree_id, cx)
674 {
675 if let Some(abs_path) = worktree
676 .read(cx)
677 .absolutize(&project_path.path)
678 .log_err()
679 {
680 abs_paths_tx.unbounded_send(abs_path)?;
681 }
682 }
683 anyhow::Ok(())
684 })??;
685 }
686 anyhow::Ok(())
687 })
688 .detach();
689 abs_paths_rx
690 })
691 })
692 })
693 }),
694 )
695 .await?
696 }
697
698 async fn outline(
699 root_dir: Option<Arc<Path>>,
700 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
701 path_str: String,
702 ) -> anyhow::Result<String> {
703 let root_dir = root_dir
704 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
705 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
706 let outline = Self::run_foreground_fn(
707 "getting code outline",
708 foreground_tx,
709 Box::new(move |session, cx| {
710 cx.spawn(move |mut cx| async move {
711 // TODO: This will not use file content from `fs_changes`. It will also reflect
712 // user changes that have not been saved.
713 let buffer = session
714 .update(&mut cx, |session, cx| {
715 session
716 .project
717 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
718 })?
719 .await?;
720 buffer.update(&mut cx, |buffer, _cx| {
721 if let Some(outline) = buffer.snapshot().outline(None) {
722 Ok(outline)
723 } else {
724 Err(anyhow!("No outline for file {path_str}"))
725 }
726 })
727 })
728 }),
729 )
730 .await?
731 .await??;
732
733 Ok(outline
734 .items
735 .into_iter()
736 .map(|item| {
737 if item.text.contains('\n') {
738 log::error!("Outline item unexpectedly contains newline");
739 }
740 format!("{}{}", " ".repeat(item.depth), item.text)
741 })
742 .collect::<Vec<String>>()
743 .join("\n"))
744 }
745
746 async fn run_foreground_fn<R: Send + 'static>(
747 description: &str,
748 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
749 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
750 ) -> anyhow::Result<R> {
751 let (response_tx, response_rx) = oneshot::channel();
752 let send_result = foreground_tx
753 .send(ForegroundFn(Box::new(move |this, cx| {
754 response_tx.send(function(this, cx)).ok();
755 })))
756 .await;
757 match send_result {
758 Ok(()) => (),
759 Err(err) => {
760 return Err(anyhow::Error::new(err).context(format!(
761 "Internal error while enqueuing work for {description}"
762 )));
763 }
764 }
765 match response_rx.await {
766 Ok(result) => Ok(result),
767 Err(oneshot::Canceled) => Err(anyhow!(
768 "Internal error: response oneshot was canceled while {description}."
769 )),
770 }
771 }
772
773 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
774 let path = Path::new(&path_str);
775 if path.is_absolute() {
776 // Check if path starts with root_dir prefix without resolving symlinks
777 if path.starts_with(&root_dir) {
778 Ok(path.to_path_buf())
779 } else {
780 Err(anyhow!(
781 "Error: Absolute path {} is outside the current working directory",
782 path_str
783 ))
784 }
785 } else {
786 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
787 Ok(root_dir.join(path))
788 }
789 }
790}
791
/// Bytes of a file opened through the sandboxed `io.open`, stored as Lua userdata in
/// the handle's `__content` field. The `RefCell` lets the handle's functions read and
/// modify the bytes through a shared reference.
struct FileContent(RefCell<Vec<u8>>);
793
794impl UserData for FileContent {
795 fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
796 // FileContent doesn't have any methods so far.
797 }
798}
799
/// Identifies a script within a `ScriptSession`: the index of the script in the
/// session's script list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
802
803pub struct Script {
804 pub state: ScriptState,
805}
806
807pub enum ScriptState {
808 Running {
809 stdout: Arc<Mutex<String>>,
810 },
811 Succeeded {
812 stdout: String,
813 },
814 Failed {
815 stdout: String,
816 error: anyhow::Error,
817 },
818}
819
820impl Script {
821 /// If exited, returns a message with the output for the LLM
822 pub fn output_message_for_llm(&self) -> Option<String> {
823 match &self.state {
824 ScriptState::Running { .. } => None,
825 ScriptState::Succeeded { stdout } => {
826 format!("Here's the script output:\n{}", stdout).into()
827 }
828 ScriptState::Failed { stdout, error } => format!(
829 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
830 error, stdout
831 )
832 .into(),
833 }
834 }
835
836 /// Get a snapshot of the script's stdout
837 pub fn stdout_snapshot(&self) -> String {
838 match &self.state {
839 ScriptState::Running { stdout } => stdout.lock().clone(),
840 ScriptState::Succeeded { stdout } => stdout.clone(),
841 ScriptState::Failed { stdout, .. } => stdout.clone(),
842 }
843 }
844}
845
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    /// Installs a test `SettingsStore` global and initializes the project settings.
    fn init_test(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(Project::init_settings);
    }

    /// Runs `source` in a fresh session over a fake file system that contains
    /// `/file1.txt` and `/file2.txt`, and returns what the script printed.
    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);

        let tree = json!({
            "file1.txt": "Hello world!",
            "file2.txt": "Goodbye moon!"
        });
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", tree).await;

        let project = Project::test(fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptSession::new(project, cx));

        let (script_id, finished) =
            session.update(cx, |session, cx| session.run_script(source.to_string(), cx));
        finished.await;

        let stdout = session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot());
        Ok(stdout)
    }

    /// `print` joins its arguments with tabs and ends each call with a newline.
    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
 print("Hello", "world!")
 print("Goodbye", "moon!")
 "#;

        let printed = test_script(script, cx).await.unwrap();
        assert_eq!(printed, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    /// `search` reports only the file containing a match, together with the match.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
 local results = search("world")
 for i, result in ipairs(results) do
 print("File: " .. result.path)
 print("Matches:")
 for j, match in ipairs(result.matches) do
 print(" " .. match)
 end
 end
 "#;

        let printed = test_script(script, cx).await.unwrap();
        assert_eq!(printed, "File: /file1.txt\nMatches:\n world\n");
    }
}