1use anyhow::anyhow;
2use collections::{HashMap, HashSet};
3use futures::{
4 channel::{mpsc, oneshot},
5 pin_mut, SinkExt, StreamExt,
6};
7use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
8use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
9use parking_lot::Mutex;
10use project::{search::SearchQuery, Fs, Project};
11use regex::Regex;
12use std::{
13 cell::RefCell,
14 path::{Path, PathBuf},
15 sync::Arc,
16};
17use util::{paths::PathMatcher, ResultExt};
18
19use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
20
/// A unit of work to run on the foreground with access to the session.
///
/// Lua callbacks running inside the background task send these through
/// `ScriptSession::foreground_fns_tx`; the task spawned in `ScriptSession::new` receives
/// them and invokes each one with a weak handle to the session and an `AsyncApp`.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
22
/// Runs Lua scripts against a project and keeps the state of every script it created.
pub struct ScriptSession {
    // The project whose first visible worktree scripts operate on.
    project: Entity<Project>,
    // TODO Remove this
    // In-memory file contents written by scripts, keyed by path. These writes are never
    // flushed to disk; later `io.open` calls consult this map before the real filesystem.
    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
    // Sender half used by Lua callbacks to schedule work on the foreground.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    // Task that drains the receiver half and runs each `ForegroundFn`; kept in the struct
    // so it is not dropped while the session exists.
    _invoke_foreground_fns: Task<()>,
    // Every script created by this session; `ScriptId` is an index into this vector.
    scripts: Vec<Script>,
}
31
32impl ScriptSession {
33 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
34 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
35 ScriptSession {
36 project,
37 fs_changes: Arc::new(Mutex::new(HashMap::default())),
38 foreground_fns_tx,
39 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
40 while let Some(foreground_fn) = foreground_fns_rx.next().await {
41 foreground_fn.0(this.clone(), cx.clone());
42 }
43 }),
44 scripts: Vec::new(),
45 }
46 }
47
48 pub fn new_script(&mut self) -> ScriptId {
49 let id = ScriptId(self.scripts.len() as u32);
50 let script = Script {
51 id,
52 state: ScriptState::Generating,
53 source: SharedString::new_static(""),
54 };
55 self.scripts.push(script);
56 id
57 }
58
59 pub fn run_script(
60 &mut self,
61 script_id: ScriptId,
62 script_src: String,
63 cx: &mut Context<Self>,
64 ) -> Task<anyhow::Result<()>> {
65 let script = self.get_mut(script_id);
66
67 let stdout = Arc::new(Mutex::new(String::new()));
68 script.source = script_src.clone().into();
69 script.state = ScriptState::Running {
70 stdout: stdout.clone(),
71 };
72
73 let task = self.run_lua(script_src, stdout, cx);
74
75 cx.emit(ScriptEvent::Spawned(script_id));
76
77 cx.spawn(|session, mut cx| async move {
78 let result = task.await;
79
80 session.update(&mut cx, |session, cx| {
81 let script = session.get_mut(script_id);
82 let stdout = script.stdout_snapshot();
83
84 script.state = match result {
85 Ok(()) => ScriptState::Succeeded { stdout },
86 Err(error) => ScriptState::Failed { stdout, error },
87 };
88
89 cx.emit(ScriptEvent::Exited(script_id))
90 })
91 })
92 }
93
94 fn run_lua(
95 &mut self,
96 script: String,
97 stdout: Arc<Mutex<String>>,
98 cx: &mut Context<Self>,
99 ) -> Task<anyhow::Result<()>> {
100 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
101
102 // TODO Remove fs_changes
103 let fs_changes = self.fs_changes.clone();
104 // TODO Honor all worktrees instead of the first one
105 let root_dir = self
106 .project
107 .read(cx)
108 .visible_worktrees(cx)
109 .next()
110 .map(|worktree| worktree.read(cx).abs_path());
111
112 let fs = self.project.read(cx).fs().clone();
113 let foreground_fns_tx = self.foreground_fns_tx.clone();
114
115 let task = cx.background_spawn({
116 let stdout = stdout.clone();
117
118 async move {
119 let lua = Lua::new();
120 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
121 let globals = lua.globals();
122 globals.set(
123 "sb_print",
124 lua.create_function({
125 let stdout = stdout.clone();
126 move |_, args: MultiValue| Self::print(args, &stdout)
127 })?,
128 )?;
129 globals.set(
130 "search",
131 lua.create_async_function({
132 let foreground_fns_tx = foreground_fns_tx.clone();
133 move |lua, regex| {
134 let mut foreground_fns_tx = foreground_fns_tx.clone();
135 let fs = fs.clone();
136 async move {
137 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
138 .await
139 .into_lua_err()
140 }
141 }
142 })?,
143 )?;
144 globals.set(
145 "outline",
146 lua.create_async_function({
147 let root_dir = root_dir.clone();
148 move |_lua, path| {
149 let mut foreground_fns_tx = foreground_fns_tx.clone();
150 let root_dir = root_dir.clone();
151 async move {
152 Self::outline(root_dir, &mut foreground_fns_tx, path)
153 .await
154 .into_lua_err()
155 }
156 }
157 })?,
158 )?;
159 globals.set(
160 "sb_io_open",
161 lua.create_function({
162 let fs_changes = fs_changes.clone();
163 let root_dir = root_dir.clone();
164 move |lua, (path_str, mode)| {
165 Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
166 }
167 })?,
168 )?;
169 globals.set("user_script", script)?;
170
171 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
172
173 // Drop Lua instance to decrement reference count.
174 drop(lua);
175
176 anyhow::Ok(())
177 }
178 });
179
180 task
181 }
182
183 pub fn get(&self, script_id: ScriptId) -> &Script {
184 &self.scripts[script_id.0 as usize]
185 }
186
187 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
188 &mut self.scripts[script_id.0 as usize]
189 }
190
191 /// Sandboxed print() function in Lua.
192 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
193 for (index, arg) in args.into_iter().enumerate() {
194 // Lua's `print()` prints tab characters between each argument.
195 if index > 0 {
196 stdout.lock().push('\t');
197 }
198
199 // If the argument's to_string() fails, have the whole function call fail.
200 stdout.lock().push_str(&arg.to_string()?);
201 }
202 stdout.lock().push('\n');
203
204 Ok(())
205 }
206
207 /// Sandboxed io.open() function in Lua.
208 fn io_open(
209 lua: &Lua,
210 fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
211 root_dir: Option<&Arc<Path>>,
212 path_str: String,
213 mode: Option<String>,
214 ) -> mlua::Result<(Option<Table>, String)> {
215 let root_dir = root_dir
216 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
217
218 let mode = mode.unwrap_or_else(|| "r".to_string());
219
220 // Parse the mode string to determine read/write permissions
221 let read_perm = mode.contains('r');
222 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
223 let append = mode.contains('a');
224 let truncate = mode.contains('w');
225
226 // This will be the Lua value returned from the `open` function.
227 let file = lua.create_table()?;
228
229 // Store file metadata in the file
230 file.set("__path", path_str.clone())?;
231 file.set("__mode", mode.clone())?;
232 file.set("__read_perm", read_perm)?;
233 file.set("__write_perm", write_perm)?;
234
235 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
236 Ok(path) => path,
237 Err(err) => return Ok((None, format!("{err}"))),
238 };
239
240 // close method
241 let close_fn = {
242 let fs_changes = fs_changes.clone();
243 lua.create_function(move |_lua, file_userdata: mlua::Table| {
244 let write_perm = file_userdata.get::<bool>("__write_perm")?;
245 let path = file_userdata.get::<String>("__path")?;
246
247 if write_perm {
248 // When closing a writable file, record the content
249 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
250 let content_ref = content.borrow::<FileContent>()?;
251 let content_vec = content_ref.0.borrow();
252
253 // Don't actually write to disk; instead, just update fs_changes.
254 let path_buf = PathBuf::from(&path);
255 fs_changes
256 .lock()
257 .insert(path_buf.clone(), content_vec.clone());
258 }
259
260 Ok(true)
261 })?
262 };
263 file.set("close", close_fn)?;
264
265 // If it's a directory, give it a custom read() and return early.
266 if path.is_dir() {
267 // TODO handle the case where we changed it in the in-memory fs
268
269 // Create a special directory handle
270 file.set("__is_directory", true)?;
271
272 // Store directory entries
273 let entries = match std::fs::read_dir(&path) {
274 Ok(entries) => {
275 let mut entry_names = Vec::new();
276 for entry in entries.flatten() {
277 entry_names.push(entry.file_name().to_string_lossy().into_owned());
278 }
279 entry_names
280 }
281 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
282 };
283
284 // Save the list of entries
285 file.set("__dir_entries", entries)?;
286 file.set("__dir_position", 0usize)?;
287
288 // Create a directory-specific read function
289 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
290 let position = file_userdata.get::<usize>("__dir_position")?;
291 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
292
293 if position >= entries.len() {
294 return Ok(None); // No more entries
295 }
296
297 let entry = entries[position].clone();
298 file_userdata.set("__dir_position", position + 1)?;
299
300 Ok(Some(entry))
301 })?;
302 file.set("read", read_fn)?;
303
304 // If we got this far, the directory was opened successfully
305 return Ok((Some(file), String::new()));
306 }
307
308 let fs_changes_map = fs_changes.lock();
309
310 let is_in_changes = fs_changes_map.contains_key(&path);
311 let file_exists = is_in_changes || path.exists();
312 let mut file_content = Vec::new();
313
314 if file_exists && !truncate {
315 if is_in_changes {
316 file_content = fs_changes_map.get(&path).unwrap().clone();
317 } else {
318 // Try to read existing content if file exists and we're not truncating
319 match std::fs::read(&path) {
320 Ok(content) => file_content = content,
321 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
322 }
323 }
324 }
325
326 drop(fs_changes_map); // Unlock the fs_changes mutex.
327
328 // If in append mode, position should be at the end
329 let position = if append && file_exists {
330 file_content.len()
331 } else {
332 0
333 };
334 file.set("__position", position)?;
335 file.set(
336 "__content",
337 lua.create_userdata(FileContent(RefCell::new(file_content)))?,
338 )?;
339
340 // Create file methods
341
342 // read method
343 let read_fn = {
344 lua.create_function(
345 |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
346 let read_perm = file_userdata.get::<bool>("__read_perm")?;
347 if !read_perm {
348 return Err(mlua::Error::runtime("File not open for reading"));
349 }
350
351 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
352 let mut position = file_userdata.get::<usize>("__position")?;
353 let content_ref = content.borrow::<FileContent>()?;
354 let content_vec = content_ref.0.borrow();
355
356 if position >= content_vec.len() {
357 return Ok(None); // EOF
358 }
359
360 match format {
361 Some(mlua::Value::String(s)) => {
362 let lossy_string = s.to_string_lossy();
363 let format_str: &str = lossy_string.as_ref();
364
365 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
366 match &format_str[0..2] {
367 "*a" => {
368 // Read entire file from current position
369 let result = String::from_utf8_lossy(&content_vec[position..])
370 .to_string();
371 position = content_vec.len();
372 file_userdata.set("__position", position)?;
373 Ok(Some(result))
374 }
375 "*l" => {
376 // Read next line
377 let mut line = Vec::new();
378 let mut found_newline = false;
379
380 while position < content_vec.len() {
381 let byte = content_vec[position];
382 position += 1;
383
384 if byte == b'\n' {
385 found_newline = true;
386 break;
387 }
388
389 // Skip \r in \r\n sequence but add it if it's alone
390 if byte == b'\r' {
391 if position < content_vec.len()
392 && content_vec[position] == b'\n'
393 {
394 position += 1;
395 found_newline = true;
396 break;
397 }
398 }
399
400 line.push(byte);
401 }
402
403 file_userdata.set("__position", position)?;
404
405 if !found_newline
406 && line.is_empty()
407 && position >= content_vec.len()
408 {
409 return Ok(None); // EOF
410 }
411
412 let result = String::from_utf8_lossy(&line).to_string();
413 Ok(Some(result))
414 }
415 "*n" => {
416 // Try to parse as a number (number of bytes to read)
417 match format_str.parse::<usize>() {
418 Ok(n) => {
419 let end =
420 std::cmp::min(position + n, content_vec.len());
421 let bytes = &content_vec[position..end];
422 let result = String::from_utf8_lossy(bytes).to_string();
423 position = end;
424 file_userdata.set("__position", position)?;
425 Ok(Some(result))
426 }
427 Err(_) => Err(mlua::Error::runtime(format!(
428 "Invalid format: {}",
429 format_str
430 ))),
431 }
432 }
433 "*L" => {
434 // Read next line keeping the end of line
435 let mut line = Vec::new();
436
437 while position < content_vec.len() {
438 let byte = content_vec[position];
439 position += 1;
440
441 line.push(byte);
442
443 if byte == b'\n' {
444 break;
445 }
446
447 // If we encounter a \r, add it and check if the next is \n
448 if byte == b'\r'
449 && position < content_vec.len()
450 && content_vec[position] == b'\n'
451 {
452 line.push(content_vec[position]);
453 position += 1;
454 break;
455 }
456 }
457
458 file_userdata.set("__position", position)?;
459
460 if line.is_empty() && position >= content_vec.len() {
461 return Ok(None); // EOF
462 }
463
464 let result = String::from_utf8_lossy(&line).to_string();
465 Ok(Some(result))
466 }
467 _ => Err(mlua::Error::runtime(format!(
468 "Unsupported format: {}",
469 format_str
470 ))),
471 }
472 }
473 Some(mlua::Value::Number(n)) => {
474 // Read n bytes
475 let n = n as usize;
476 let end = std::cmp::min(position + n, content_vec.len());
477 let bytes = &content_vec[position..end];
478 let result = String::from_utf8_lossy(bytes).to_string();
479 position = end;
480 file_userdata.set("__position", position)?;
481 Ok(Some(result))
482 }
483 Some(_) => Err(mlua::Error::runtime("Invalid format")),
484 None => {
485 // Default is to read a line
486 let mut line = Vec::new();
487 let mut found_newline = false;
488
489 while position < content_vec.len() {
490 let byte = content_vec[position];
491 position += 1;
492
493 if byte == b'\n' {
494 found_newline = true;
495 break;
496 }
497
498 // Handle \r\n
499 if byte == b'\r' {
500 if position < content_vec.len()
501 && content_vec[position] == b'\n'
502 {
503 position += 1;
504 found_newline = true;
505 break;
506 }
507 }
508
509 line.push(byte);
510 }
511
512 file_userdata.set("__position", position)?;
513
514 if !found_newline && line.is_empty() && position >= content_vec.len() {
515 return Ok(None); // EOF
516 }
517
518 let result = String::from_utf8_lossy(&line).to_string();
519 Ok(Some(result))
520 }
521 }
522 },
523 )?
524 };
525 file.set("read", read_fn)?;
526
527 // write method
528 let write_fn = {
529 let fs_changes = fs_changes.clone();
530
531 lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
532 let write_perm = file_userdata.get::<bool>("__write_perm")?;
533 if !write_perm {
534 return Err(mlua::Error::runtime("File not open for writing"));
535 }
536
537 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
538 let position = file_userdata.get::<usize>("__position")?;
539 let content_ref = content.borrow::<FileContent>()?;
540 let mut content_vec = content_ref.0.borrow_mut();
541
542 let bytes = text.as_bytes();
543
544 // Ensure the vector has enough capacity
545 if position + bytes.len() > content_vec.len() {
546 content_vec.resize(position + bytes.len(), 0);
547 }
548
549 // Write the bytes
550 for (i, &byte) in bytes.iter().enumerate() {
551 content_vec[position + i] = byte;
552 }
553
554 // Update position
555 let new_position = position + bytes.len();
556 file_userdata.set("__position", new_position)?;
557
558 // Update fs_changes
559 let path = file_userdata.get::<String>("__path")?;
560 let path_buf = PathBuf::from(path);
561 fs_changes.lock().insert(path_buf, content_vec.clone());
562
563 Ok(true)
564 })?
565 };
566 file.set("write", write_fn)?;
567
568 // If we got this far, the file was opened successfully
569 Ok((Some(file), String::new()))
570 }
571
572 async fn search(
573 lua: &Lua,
574 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
575 fs: Arc<dyn Fs>,
576 regex: String,
577 ) -> anyhow::Result<Table> {
578 // TODO: Allow specification of these options.
579 let search_query = SearchQuery::regex(
580 ®ex,
581 false,
582 false,
583 false,
584 PathMatcher::default(),
585 PathMatcher::default(),
586 None,
587 );
588 let search_query = match search_query {
589 Ok(query) => query,
590 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
591 };
592
593 // TODO: Should use `search_query.regex`. The tool description should also be updated,
594 // as it specifies standard regex.
595 let search_regex = match Regex::new(®ex) {
596 Ok(re) => re,
597 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
598 };
599
600 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
601
602 let mut search_results: Vec<Table> = Vec::new();
603 while let Some(path) = abs_paths_rx.next().await {
604 // Skip files larger than 1MB
605 if let Ok(Some(metadata)) = fs.metadata(&path).await {
606 if metadata.len > 1_000_000 {
607 continue;
608 }
609 }
610
611 // Attempt to read the file as text
612 if let Ok(content) = fs.load(&path).await {
613 let mut matches = Vec::new();
614
615 // Find all regex matches in the content
616 for capture in search_regex.find_iter(&content) {
617 matches.push(capture.as_str().to_string());
618 }
619
620 // If we found matches, create a result entry
621 if !matches.is_empty() {
622 let result_entry = lua.create_table()?;
623 result_entry.set("path", path.to_string_lossy().to_string())?;
624
625 let matches_table = lua.create_table()?;
626 for (ix, m) in matches.iter().enumerate() {
627 matches_table.set(ix + 1, m.clone())?;
628 }
629 result_entry.set("matches", matches_table)?;
630
631 search_results.push(result_entry);
632 }
633 }
634 }
635
636 // Create a table to hold our results
637 let results_table = lua.create_table()?;
638 for (ix, entry) in search_results.into_iter().enumerate() {
639 results_table.set(ix + 1, entry)?;
640 }
641
642 Ok(results_table)
643 }
644
645 async fn find_search_candidates(
646 search_query: SearchQuery,
647 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
648 ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
649 Self::run_foreground_fn(
650 "finding search file candidates",
651 foreground_tx,
652 Box::new(move |session, mut cx| {
653 session.update(&mut cx, |session, cx| {
654 session.project.update(cx, |project, cx| {
655 project.worktree_store().update(cx, |worktree_store, cx| {
656 // TODO: Better limit? For now this is the same as
657 // MAX_SEARCH_RESULT_FILES.
658 let limit = 5000;
659 // TODO: Providing non-empty open_entries can make this a bit more
660 // efficient as it can skip checking that these paths are textual.
661 let open_entries = HashSet::default();
662 let candidates = worktree_store.find_search_candidates(
663 search_query,
664 limit,
665 open_entries,
666 project.fs().clone(),
667 cx,
668 );
669 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
670 cx.spawn(|worktree_store, cx| async move {
671 pin_mut!(candidates);
672
673 while let Some(project_path) = candidates.next().await {
674 worktree_store.read_with(&cx, |worktree_store, cx| {
675 if let Some(worktree) = worktree_store
676 .worktree_for_id(project_path.worktree_id, cx)
677 {
678 if let Some(abs_path) = worktree
679 .read(cx)
680 .absolutize(&project_path.path)
681 .log_err()
682 {
683 abs_paths_tx.unbounded_send(abs_path)?;
684 }
685 }
686 anyhow::Ok(())
687 })??;
688 }
689 anyhow::Ok(())
690 })
691 .detach();
692 abs_paths_rx
693 })
694 })
695 })
696 }),
697 )
698 .await?
699 }
700
701 async fn outline(
702 root_dir: Option<Arc<Path>>,
703 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
704 path_str: String,
705 ) -> anyhow::Result<String> {
706 let root_dir = root_dir
707 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
708 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
709 let outline = Self::run_foreground_fn(
710 "getting code outline",
711 foreground_tx,
712 Box::new(move |session, cx| {
713 cx.spawn(move |mut cx| async move {
714 // TODO: This will not use file content from `fs_changes`. It will also reflect
715 // user changes that have not been saved.
716 let buffer = session
717 .update(&mut cx, |session, cx| {
718 session
719 .project
720 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
721 })?
722 .await?;
723 buffer.update(&mut cx, |buffer, _cx| {
724 if let Some(outline) = buffer.snapshot().outline(None) {
725 Ok(outline)
726 } else {
727 Err(anyhow!("No outline for file {path_str}"))
728 }
729 })
730 })
731 }),
732 )
733 .await?
734 .await??;
735
736 Ok(outline
737 .items
738 .into_iter()
739 .map(|item| {
740 if item.text.contains('\n') {
741 log::error!("Outline item unexpectedly contains newline");
742 }
743 format!("{}{}", " ".repeat(item.depth), item.text)
744 })
745 .collect::<Vec<String>>()
746 .join("\n"))
747 }
748
749 async fn run_foreground_fn<R: Send + 'static>(
750 description: &str,
751 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
752 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
753 ) -> anyhow::Result<R> {
754 let (response_tx, response_rx) = oneshot::channel();
755 let send_result = foreground_tx
756 .send(ForegroundFn(Box::new(move |this, cx| {
757 response_tx.send(function(this, cx)).ok();
758 })))
759 .await;
760 match send_result {
761 Ok(()) => (),
762 Err(err) => {
763 return Err(anyhow::Error::new(err).context(format!(
764 "Internal error while enqueuing work for {description}"
765 )));
766 }
767 }
768 match response_rx.await {
769 Ok(result) => Ok(result),
770 Err(oneshot::Canceled) => Err(anyhow!(
771 "Internal error: response oneshot was canceled while {description}."
772 )),
773 }
774 }
775
776 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
777 let path = Path::new(&path_str);
778 if path.is_absolute() {
779 // Check if path starts with root_dir prefix without resolving symlinks
780 if path.starts_with(&root_dir) {
781 Ok(path.to_path_buf())
782 } else {
783 Err(anyhow!(
784 "Error: Absolute path {} is outside the current working directory",
785 path_str
786 ))
787 }
788 } else {
789 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
790 Ok(root_dir.join(path))
791 }
792 }
793}
794
/// In-memory bytes of a file opened through the sandboxed `io.open`.
///
/// Stored as Lua userdata in the handle's `__content` field so the handle's `read`,
/// `write` and `close` methods all operate on the same buffer.
struct FileContent(RefCell<Vec<u8>>);
796
// Lets `FileContent` be wrapped with `lua.create_userdata`; it is only accessed from Rust
// via `AnyUserData::borrow`, so no Lua-visible methods are registered.
impl UserData for FileContent {
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}
802
/// Lifecycle notifications emitted by [`ScriptSession`].
#[derive(Debug)]
pub enum ScriptEvent {
    /// Emitted by `run_script` right after the script's Lua task has been started.
    Spawned(ScriptId),
    /// Emitted once the script's Lua task finished and its final state was recorded.
    Exited(ScriptId),
}
808
// Allows `cx.emit(ScriptEvent::...)` from within `ScriptSession` methods.
impl EventEmitter<ScriptEvent> for ScriptSession {}
810
/// Identifies a script within its [`ScriptSession`]; the wrapped value is the script's
/// index in the session's script list, assigned by `ScriptSession::new_script`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
813
/// A script tracked by a [`ScriptSession`].
pub struct Script {
    /// This script's id within its session.
    pub id: ScriptId,
    /// Current lifecycle state, including any captured output.
    pub state: ScriptState,
    /// Lua source; empty until `ScriptSession::run_script` sets it.
    pub source: SharedString,
}
819
/// Lifecycle state of a [`Script`].
pub enum ScriptState {
    /// Created by `new_script` but not yet run.
    Generating,
    /// Executing; `stdout` is shared with the sandboxed `print` and grows as it runs.
    Running {
        stdout: Arc<Mutex<String>>,
    },
    /// Finished without error; holds a snapshot of everything printed.
    Succeeded {
        stdout: String,
    },
    /// Finished with `error`; holds whatever was printed before the failure.
    Failed {
        stdout: String,
        error: anyhow::Error,
    },
}
833
834impl Script {
835 pub fn source_tag(&self) -> String {
836 format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
837 }
838
839 /// If exited, returns a message with the output for the LLM
840 pub fn output_message_for_llm(&self) -> Option<String> {
841 match &self.state {
842 ScriptState::Generating { .. } => None,
843 ScriptState::Running { .. } => None,
844 ScriptState::Succeeded { stdout } => {
845 format!("Here's the script output:\n{}", stdout).into()
846 }
847 ScriptState::Failed { stdout, error } => format!(
848 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
849 error, stdout
850 )
851 .into(),
852 }
853 }
854
855 /// Get a snapshot of the script's stdout
856 pub fn stdout_snapshot(&self) -> String {
857 match &self.state {
858 ScriptState::Generating { .. } => String::new(),
859 ScriptState::Running { stdout } => stdout.lock().clone(),
860 ScriptState::Succeeded { stdout } => stdout.clone(),
861 ScriptState::Failed { stdout, .. } => stdout.clone(),
862 }
863 }
864
865 /// Returns the error if the script failed, otherwise None
866 pub fn error(&self) -> Option<&anyhow::Error> {
867 match &self.state {
868 ScriptState::Generating { .. } => None,
869 ScriptState::Running { .. } => None,
870 ScriptState::Succeeded { .. } => None,
871 ScriptState::Failed { error, .. } => Some(error),
872 }
873 }
874}
875
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let source = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let printed = test_script(source, cx).await.unwrap();
        assert_eq!(printed, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let source = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print(" " .. match)
                end
            end
        "#;

        let printed = test_script(source, cx).await.unwrap();
        assert_eq!(printed, "File: /file1.txt\nMatches:\n world\n");
    }

    /// Runs `source` in a fresh session over a fake project rooted at `/` containing
    /// `file1.txt` and `file2.txt`, and returns everything the script printed.
    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let tree = json!({
            "file1.txt": "Hello world!",
            "file2.txt": "Goodbye moon!"
        });
        fs.insert_tree("/", tree).await;

        let project = Project::test(fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptSession::new(project, cx));

        let (script_id, run) = session.update(cx, |session, cx| {
            let id = session.new_script();
            (id, session.run_script(id, source.to_string(), cx))
        });
        run.await?;

        let printed = session.read_with(cx, |session, _cx| {
            session.get(script_id).stdout_snapshot()
        });
        Ok(printed)
    }

    /// Installs the test settings store and the project settings the session depends on.
    fn init_test(cx: &mut TestAppContext) {
        let store = cx.update(SettingsStore::test);
        cx.set_global(store);
        cx.update(Project::init_settings);
    }
}